~azzar1/unity/add-show-desktop-key

« back to all changes in this revision

Viewing changes to www/apps/tutorial/test/TestFramework.py

  • Committer: mattgiuca
  • Date: 2008-01-24 23:57:26 UTC
  • Revision ID: svn-v3-trunk0:2b9c9e99-6f39-0410-b283-7f802c844ae2:trunk:294
Added application: tutorialservice. Will be used as the Ajax backend for
    tutorial (currently empty).
Moved tutorial/test to tutorialservice/test.
Reason: The testing framework will not be used by the tutorial HTML-side app.
    Only by the ajax backend.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# IVLE - Informatics Virtual Learning Environment
2
 
# Copyright (C) 2007-2008 The University of Melbourne
3
 
#
4
 
# This program is free software; you can redistribute it and/or modify
5
 
# it under the terms of the GNU General Public License as published by
6
 
# the Free Software Foundation; either version 2 of the License, or
7
 
# (at your option) any later version.
8
 
#
9
 
# This program is distributed in the hope that it will be useful,
10
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
# GNU General Public License for more details.
13
 
#
14
 
# You should have received a copy of the GNU General Public License
15
 
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
17
 
 
18
 
# Module: TestFramework
19
 
# Author: Dilshan Angampitiya
20
 
# Date:   24/1/2008
21
 
 
22
 
# Brief description of the Module# define custom exceptions
23
 
# use exceptions for all errors found in testing
24
 
 
25
 
import sys, StringIO, copy
26
 
 
27
 
# student error
28
 
class FunctionNotFoundError(Exception):
29
 
    """This error is returned when a function was expected in student
30
 
    code but was not found"""
31
 
    def __init__(self, function_name):
32
 
        self.function_name = function_name
33
 
 
34
 
    def __str__(self):
35
 
        return "Function " + self.function_name + " not found"
36
 
 
37
 
# author error
38
 
class TestCreationError(Exception):
39
 
    """An error occured while creating the test suite or one of its components"""
40
 
    def __init__(self, reason):
41
 
        self._reason = reason
42
 
        
43
 
    def __str__(self):
44
 
        return self._reason
45
 
 
46
 
# author error
47
 
class SolutionError(Exception):
48
 
    """Error in the provided solution"""
49
 
    def __init__(self, exc_info):
50
 
        cla, exc, trbk = exc_info
51
 
        self.name = cla.__name__
52
 
        self._detail = str(exc)
53
 
 
54
 
    def __str__(self):
55
 
        return "Error running solution: %s" %str(self._detail)
56
 
 
57
 
# author error
58
 
class TestError(Exception):
59
 
    """Runtime error in the testing framework outside of the provided or student code"""
60
 
    def __init__(self, exc_info):
61
 
        cla, exc, trbk = exc_info
62
 
        self.name = cla.__name__
63
 
        self._detail = str(exc)
64
 
 
65
 
    def __str__(self):
66
 
        return "Error testing solution against attempt: %s" %str(self._detail)
67
 
 
68
 
# student error
69
 
class AttemptError(Exception):
70
 
    """Runtime error in the student code"""
71
 
    def __init__(self, exc_info):
72
 
        cla, exc, trbk = exc_info
73
 
        self._name = cla.__name__
74
 
        self._detail = str(exc)
75
 
 
76
 
    def is_critical(self):
77
 
        if (    self._name == 'FunctionNotFoundError'
78
 
            or  self._name == 'SyntaxError'
79
 
            or  self._name == 'IndentationError'):
80
 
            return True
81
 
        else:
82
 
            return False
83
 
 
84
 
    def to_dict(self):
85
 
        return {'name': self._name,
86
 
                'detail': self._detail,
87
 
                'critical': self.is_critical()
88
 
                }
89
 
 
90
 
    def __str__(self):
91
 
        return self._name + " - " + str(self._detail)
92
 
 
93
 
class TestCasePart:
94
 
    """
95
 
    A part of a test case which compares a subset of the input files or file streams.
96
 
    This can be done either with a comparision function, or by comparing directly, after
97
 
    applying normalisations.
98
 
    """
99
 
    # how to make this work? atm they seem to get passed the class as a first arg
100
 
    ident =lambda x: x
101
 
    ignore = lambda x: None
102
 
    match = lambda x,y: x==y
103
 
    always_match = lambda x,y: True
104
 
    true = lambda *x: True
105
 
    false = lambda *x: False
106
 
 
107
 
    def __init__(self, desc, default='match'):
108
 
        """Initialise with a description and a default behavior for output
109
 
        If default is match, unspecified files are matched exactly
110
 
        If default is ignore, unspecified files are ignored
111
 
        The default default is match.
112
 
        """
113
 
        self._desc = desc
114
 
        self._default = default
115
 
        if default == 'ignore':
116
 
            self._default_func = lambda *x: True
117
 
        else:
118
 
            self._default_func = lambda x,y: x==y
119
 
 
120
 
        self._file_tests = {}
121
 
        self._stdout_test = ('check', self._default_func)
122
 
        self._stderr_test = ('check', self._default_func)
123
 
        self._result_test = ('check', self._default_func)
124
 
 
125
 
    def get_description(self):
126
 
        "Getter for description"
127
 
        return self._desc
128
 
 
129
 
    def _set_default_function(self, function, test_type):
130
 
        """"Ensure test type is valid and set function to a default
131
 
        if not specified"""
132
 
        
133
 
        if test_type not in ['norm', 'check']:
134
 
            raise TestCreationError("Invalid test type in %s" %self._desc)
135
 
        
136
 
        if function == '':
137
 
            if test_type == 'norm': function = lambda x: x
138
 
            else: function = lambda x,y: x==y
139
 
 
140
 
        return function
141
 
 
142
 
    def _validate_function(self, function, included_code):
143
 
        """Create a function object from the given string.
144
 
        If a valid function object cannot be created, raise and error.
145
 
        """
146
 
        if not callable(function):
147
 
            try:
148
 
                exec "__f__ = %s" %function in included_code
149
 
            except:
150
 
                raise TestCreationError("Invalid function %s" %function)
151
 
 
152
 
            f = included_code['__f__']
153
 
 
154
 
            if not callable(f):
155
 
                raise TestCreationError("Invalid function %s" %function)    
156
 
        else:
157
 
            f = function
158
 
 
159
 
        return f
160
 
 
161
 
    def validate_functions(self, included_code):
162
 
        """Ensure all functions used by the test cases exist and are callable.
163
 
        Also covert their string representations to function objects.
164
 
        This can only be done once all the include code has been specified.
165
 
        """
166
 
        (test_type, function) = self._stdout_test
167
 
        self._stdout_test = (test_type, self._validate_function(function, included_code))
168
 
        
169
 
        (test_type, function) = self._stderr_test
170
 
        self._stderr_test = (test_type, self._validate_function(function, included_code))
171
 
 
172
 
        for filename, (test_type, function) in self._file_tests.items():
173
 
            self._file_tests[filename] = (test_type, self._validate_function(function, included_code))
174
 
            
175
 
    def add_result_test(self, function, test_type='norm'):
176
 
        "Test part that compares function return values"
177
 
        function = self._set_default_function(function, test_type)
178
 
        self._result_test = (test_type, function)
179
 
 
180
 
            
181
 
    def add_stdout_test(self, function, test_type='norm'):
182
 
        "Test part that compares stdout"
183
 
        function = self._set_default_function(function, test_type)
184
 
        self._stdout_test = (test_type, function)
185
 
        
186
 
 
187
 
    def add_stderr_test(self, function, test_type='norm'):
188
 
        "Test part that compares stderr"
189
 
        function = self._set_default_function(function, test_type)
190
 
        self._stderr_test = (test_type, function)
191
 
 
192
 
    def add_file_test(self, filename, function, test_type='norm'):
193
 
        "Test part that compares the contents of a specified file"
194
 
        function = self._set_default_function(function, test_type)
195
 
        self._file_tests[filename] = (test_type, function)
196
 
 
197
 
    def _check_output(self, solution_output, attempt_output, test_type, f):
198
 
        """Compare solution output and attempt output using the
199
 
        specified comparision function.
200
 
        """
201
 
        # converts unicode to string
202
 
        if type(solution_output) == unicode:    
203
 
            solution_output = str(solution_output)
204
 
        if type(attempt_output) == unicode:
205
 
            attempt_output = str(attempt_output)
206
 
            
207
 
        if test_type == 'norm':
208
 
            return f(solution_output) == f(attempt_output)
209
 
        else:
210
 
            return f(solution_output, attempt_output)
211
 
 
212
 
    def run(self, solution_data, attempt_data):
213
 
        """Run the tests to compare the solution and attempt data
214
 
        Returns the empty string is the test passes, or else an error message.
215
 
        """
216
 
 
217
 
        # check function return value (None for scripts)
218
 
        (test_type, f) = self._result_test
219
 
        if not self._check_output(solution_data['result'], attempt_data['result'], test_type, f):       
220
 
            return 'function return value does not match'
221
 
 
222
 
        # check stdout
223
 
        (test_type, f) = self._stdout_test
224
 
        if not self._check_output(solution_data['stdout'], attempt_data['stdout'], test_type, f):       
225
 
            return 'stdout does not match'
226
 
 
227
 
        #check stderr
228
 
        (test_type, f) = self._stderr_test
229
 
        if not self._check_output(solution_data['stderr'], attempt_data['stderr'], test_type, f):        
230
 
            return 'stderr does not match'
231
 
 
232
 
 
233
 
        solution_files = solution_data['modified_files']
234
 
        attempt_files = attempt_data['modified_files']
235
 
 
236
 
        # check files indicated by test
237
 
        for (filename, (test_type, f)) in self._file_tests.items():
238
 
            if filename not in solution_files:
239
 
                raise SolutionError('File %s not found' %filename)
240
 
            elif filename not in attempt_files:
241
 
                return filename + ' not found'
242
 
            elif not self._check_output(solution_files[filename], attempt_files[filename], test_type, f):
243
 
                return filename + ' does not match'
244
 
 
245
 
        if self._default == 'ignore':
246
 
            return ''
247
 
 
248
 
        # check files found in solution, but not indicated by test
249
 
        for filename in [f for f in solution_files if f not in self._file_tests]:
250
 
            if filename not in attempt_files:
251
 
                return filename + ' not found'
252
 
            elif not self._check_output(solution_files[filename], attempt_files[filename], 'match', lambda x,y: x==y):
253
 
                return filename + ' does not match'
254
 
 
255
 
        # check if attempt has any extra files
256
 
        for filename in [f for f in attempt_files if f not in solution_files]:
257
 
            return "Unexpected file found: " + filename
258
 
 
259
 
        # Everything passed with no problems
260
 
        return ''
261
 
        
262
 
class TestCase:
263
 
    """
264
 
    A set of tests with a common inputs
265
 
    """
266
 
    def __init__(self, name='', function=None, stdin='', filespace=None, global_space=None):
267
 
        """Initialise with name and optionally, a function to test (instead of the entire script)
268
 
        The inputs stdin, the filespace and global variables can also be specified at
269
 
        initialisation, but may also be set later.
270
 
        """
271
 
        if global_space == None:
272
 
            global_space = {}
273
 
        if filespace == None:
274
 
            filespace = {}
275
 
        
276
 
        self._name = name
277
 
        
278
 
        if function == '': function = None
279
 
        self._function = function
280
 
        self._list_args = []
281
 
        self._keyword_args = {}
282
 
        
283
 
        # stdin must have a newline at the end for raw_input to work properly
284
 
        if stdin[-1:] != '\n': stdin += '\n'
285
 
        
286
 
        self._stdin = stdin
287
 
        self._filespace = TestFilespace(filespace)
288
 
        self._global_space = global_space
289
 
        self._parts = []
290
 
 
291
 
    def set_stdin(self, stdin):
292
 
        """ Set the given string as the stdin for this test case"""
293
 
        self._stdin = stdin
294
 
 
295
 
    def add_file(self, filename, data):
296
 
        """ Insert the given filename-data pair into the filespace for this test case"""
297
 
        self._filespace.add_file(filename, data)
298
 
        
299
 
    def add_variable(self, variable, value):
300
 
        """ Add the given varibale-value pair to the initial global environment
301
 
        for this test case.
302
 
        Throw and exception if thevalue cannot be paresed.
303
 
        """
304
 
        
305
 
        try:
306
 
            self._global_space[variable] = eval(value)
307
 
        except:
308
 
            raise TestCreationError("Invalid value for variable %s: %s" %(variable, value))
309
 
 
310
 
    def add_arg(self, value, name=None):
311
 
        """ Add a value to the argument list. This only applies when testing functions.
312
 
        By default arguments are not named, but if they are, they become keyword arguments.
313
 
        """
314
 
        try:
315
 
            if name == None or name == '':
316
 
                self._list_args.append(eval(value))
317
 
            else:
318
 
                self._keyword_args[name] = value
319
 
        except:
320
 
            raise TestCreationError("Invalid value for function argument: %s" %value)
321
 
        
322
 
    def add_part(self, test_part):
323
 
        """ Add a TestPart to this test case"""
324
 
        self._parts.append(test_part)
325
 
 
326
 
    def validate_functions(self, included_code):
327
 
        """ Validate all the functions in each part in this test case
328
 
        This can only be done once all the include code has been specified.
329
 
        """
330
 
        for part in self._parts:
331
 
            part.validate_functions(included_code)
332
 
 
333
 
    def get_name(self):
334
 
        """ Get the name of the test case """
335
 
        return self._name
336
 
 
337
 
    def run(self, solution, attempt_file):
338
 
        """ Run the solution and the attempt with the inputs specified for this test case.
339
 
        Then pass the outputs to each test part and collate the results.
340
 
        """
341
 
        case_dict = {}
342
 
        case_dict['name'] = self._name
343
 
        
344
 
        # Run solution
345
 
        try:
346
 
            global_space_copy = copy.deepcopy(self._global_space)
347
 
            solution_data = self._execstring(solution, global_space_copy)
348
 
            
349
 
            # if we are just testing a function
350
 
            if not self._function == None:
351
 
                if self._function not in global_space_copy:
352
 
                    raise FunctionNotFoundError(self._function)
353
 
                solution_data = self._run_function(lambda: global_space_copy[self._function](*self._list_args, **self._keyword_args))
354
 
                
355
 
        except:
356
 
            raise SolutionError(sys.exc_info())
357
 
 
358
 
        # Run student attempt
359
 
        try:
360
 
            global_space_copy = copy.deepcopy(self._global_space)
361
 
            attempt_data = self._execfile(attempt_file, global_space_copy)
362
 
            
363
 
            # if we are just testing a function
364
 
            if not self._function == None:
365
 
                if self._function not in global_space_copy:
366
 
                    raise FunctionNotFoundError(self._function)
367
 
                attempt_data = self._run_function(lambda: global_space_copy[self._function](*self._list_args, **self._keyword_args))
368
 
        except:
369
 
            case_dict['exception'] = AttemptError(sys.exc_info()).to_dict()
370
 
            return case_dict
371
 
        
372
 
        results = []
373
 
 
374
 
        # generate results
375
 
        for test_part in self._parts:
376
 
            result = test_part.run(solution_data, attempt_data)
377
 
            result_dict = {}
378
 
            result_dict['description'] = test_part.get_description()
379
 
            result_dict['passed']  = (result == '')
380
 
            if result_dict['passed'] == False:
381
 
                result_dict['error_message'] = result
382
 
                
383
 
            results.append(result_dict)
384
 
 
385
 
        case_dict['parts'] = results
386
 
 
387
 
        return case_dict
388
 
                
389
 
    def _execfile(self, filename, global_space):
390
 
        """ Execute the file given by 'filename' in global_space, and return the outputs. """
391
 
        self._initialise_global_space(global_space)
392
 
        data = self._run_function(lambda: execfile(filename, global_space))
393
 
        return data
394
 
 
395
 
    def _execstring(self, string, global_space):
396
 
        """ Execute the given string in global_space, and return the outputs. """
397
 
        self._initialise_global_space(global_space)
398
 
        # _run_function handles tuples in a special way
399
 
        data = self._run_function((string, global_space))
400
 
        return data
401
 
 
402
 
    def _initialise_global_space(self, global_space):
403
 
        """ Modify the provided global_space so that file, open and raw_input are redefined
404
 
        to use our methods instead.
405
 
        """
406
 
        self._current_filespace_copy = self._filespace.copy()
407
 
        global_space['file'] = lambda filename, mode='r', bufsize=-1: self._current_filespace_copy.openfile(filename, mode)
408
 
        global_space['open'] = global_space['file']
409
 
        global_space['raw_input'] = lambda x=None: raw_input()
410
 
        return global_space
411
 
 
412
 
    def _run_function(self, function):
413
 
        """ Run the provided function with the provided stdin, capturing stdout and stderr
414
 
        and the return value.
415
 
        Return all the output data.
416
 
        """
417
 
        import sys, StringIO
418
 
        sys_stdout, sys_stdin, sys_stderr = sys.stdout, sys.stdin, sys.stderr
419
 
 
420
 
        output_stream, input_stream, error_stream = StringIO.StringIO(), StringIO.StringIO(self._stdin), StringIO.StringIO()
421
 
        sys.stdout, sys.stdin, sys.stderr = output_stream, input_stream, error_stream
422
 
 
423
 
        try:
424
 
            if type(function) == tuple:
425
 
                # very hackish... exec can't be put into a lambda function!
426
 
                # or even with eval
427
 
                exec(function[0], function[1])
428
 
                result = None
429
 
            else:
430
 
                result = function()
431
 
        except:
432
 
            sys.stdout, sys.stdin, sys.stderr = sys_stdout, sys_stdin, sys_stderr
433
 
            raise
434
 
        
435
 
        sys.stdout, sys.stdin, sys.stderr = sys_stdout, sys_stdin, sys_stderr
436
 
 
437
 
        self._current_filespace_copy.flush_all()
438
 
            
439
 
        return {'result': result,
440
 
                'stdout': output_stream.getvalue(),
441
 
                'stderr': output_stream.getvalue(),
442
 
                'modified_files': self._current_filespace_copy.get_modified_files()}
443
 
 
444
 
class TestSuite:
445
 
    """
446
 
    The complete collection of test cases for a given problem
447
 
    """
448
 
    def __init__(self, name, solution=None):
449
 
        """Initialise with the name of the test suite (the problem name) and the solution.
450
 
        The solution may be specified later.
451
 
        """
452
 
        self._solution = solution
453
 
        self._name = name
454
 
        self._tests = []
455
 
        self.add_include_code("")
456
 
 
457
 
    def add_solution(self, solution):
458
 
        " Specifiy the solution script for this problem "
459
 
        self._solution = solution
460
 
 
461
 
    def has_solution(self):
462
 
        " Returns true if a soltion has been provided "
463
 
        return self._solution != None
464
 
 
465
 
    def add_include_code(self, include_code = ''):
466
 
        """ Add include code that may be used by the test cases during
467
 
        comparison of outputs.
468
 
        """
469
 
        
470
 
        # if empty, make sure it can still be executed
471
 
        if include_code == "":
472
 
            include_code = "pass"
473
 
        self._include_code = str(include_code)
474
 
        
475
 
        include_space = {}
476
 
        try:
477
 
            exec self._include_code in include_space
478
 
        except:
479
 
            raise TestCreationError("Bad include code")
480
 
 
481
 
        self._include_space = include_space
482
 
    
483
 
    def add_case(self, test_case):
484
 
        """ Add a TestCase, then validate all functions inside test case
485
 
        now that the include code is known
486
 
        """
487
 
        self._tests.append(test_case)
488
 
        test_case.validate_functions(self._include_space)
489
 
 
490
 
    def run_tests(self, attempt_file):
491
 
        " Run all test cases and collate the results "
492
 
        
493
 
        problem_dict = {}
494
 
        problem_dict['name'] = self._name
495
 
        
496
 
        test_case_results = []
497
 
        for test in self._tests:
498
 
            result_dict = test.run(self._solution, attempt_file)
499
 
            if 'exception' in result_dict and result_dict['exception']['critical']:
500
 
                # critical error occured, running more cases is useless
501
 
                # FunctionNotFound, Syntax, Indentation
502
 
                problem_dict['critical_error'] = result_dict['exception']
503
 
                return problem_dict
504
 
            
505
 
            test_case_results.append(result_dict)
506
 
 
507
 
        problem_dict['cases'] = test_case_results
508
 
        return problem_dict
509
 
 
510
 
    def get_name(self):
511
 
        return self._name
512
 
 
513
 
class TestFilespace:
514
 
    """
515
 
    Our dummy file system which is accessed by code being tested.
516
 
    Implemented as a dictionary which maps filenames to strings
517
 
    """
518
 
    def __init__(self, files=None):
519
 
        "Initialise, optionally with filename-filedata pairs"
520
 
 
521
 
        if files == None:
522
 
            files = {}
523
 
 
524
 
        # dict mapping files to strings
525
 
        self._files = {}
526
 
        self._files.update(files)
527
 
        # set of file names
528
 
        self._modified_files = set([])
529
 
        # dict mapping files to stringIO objects
530
 
        self._open_files = {}
531
 
 
532
 
    def add_file(self, filename, data):
533
 
        " Add a file to the filespace "
534
 
        self._files[filename] = data
535
 
 
536
 
    def openfile(self, filename, mode='r'):
537
 
        """ Open a file from the filespace with the given mode.
538
 
        Return a StringIO subclass object with the file contents.
539
 
        """
540
 
        import re
541
 
 
542
 
        if filename in self._open_files:
543
 
            raise IOError("File already open: %s" %filename)
544
 
 
545
 
        if not re.compile("[rwa][+b]{0,2}").match(mode):
546
 
            raise IOError("invalid mode %s" %mode)
547
 
        
548
 
        ## TODO: validate filename?
549
 
        
550
 
        mode.replace("b",'')
551
 
        
552
 
        # initialise the file properly (truncate/create if required)
553
 
        if mode[0] == 'w':
554
 
            self._files[filename] = ''
555
 
            self._modified_files.add(filename)
556
 
        elif filename not in self._files:
557
 
            if mode[0] == 'a':
558
 
                self._files[filename] = ''
559
 
                self._modified_files.add(filename)
560
 
            else:
561
 
                raise IOError(2, "Access to file denied: %s" %filename)
562
 
 
563
 
        # for append mode, remember the existing data
564
 
        if mode[0] == 'a':
565
 
            existing_data = self._files[filename]
566
 
        else:
567
 
            existing_data = ""
568
 
 
569
 
        # determine what operations are allowed
570
 
        reading_ok = (len(mode) == 2 or mode[0] == 'r')
571
 
        writing_ok = (len(mode) == 2 or mode[0] in 'wa')
572
 
 
573
 
        # for all writing modes, start off with blank file
574
 
        if mode[0] == 'w':
575
 
            initial_data = ''
576
 
        else:
577
 
            initial_data = self._files[filename]
578
 
 
579
 
        file_object = TestStringIO(initial_data, filename, self, reading_ok, writing_ok, existing_data)
580
 
        self._open_files[filename] = file_object
581
 
        
582
 
        return file_object
583
 
 
584
 
    def flush_all(self):
585
 
        """ Flush all open files
586
 
        """
587
 
        for file_object in self._open_files.values():
588
 
            file_object.flush()
589
 
 
590
 
    def updatefile(self,filename, data):
591
 
        """ Callback function used by an open file to inform when it has been updated.
592
 
        """
593
 
        if filename in self._open_files:
594
 
            self._files[filename] = data
595
 
            if self._open_files[filename].is_modified():
596
 
                self._modified_files.add(filename)
597
 
        else:
598
 
            raise IOError(2, "Access to file denied: %s" %filename)
599
 
 
600
 
    def closefile(self, filename):
601
 
        """ Callback function used by an open file to inform when it has been closed.
602
 
        """
603
 
        if filename in self._open_files:
604
 
            del self._open_files[filename]
605
 
 
606
 
    def get_modified_files(self):
607
 
        """" A subset of the filespace containing only those files which have been
608
 
        modified
609
 
        """
610
 
        modified_files = {}
611
 
        for filename in self._modified_files:
612
 
            modified_files[filename] = self._files[filename]
613
 
 
614
 
        return modified_files
615
 
 
616
 
    def get_open_files(self):
617
 
        " Return the names of all open files "
618
 
        return self._open_files.keys()
619
 
            
620
 
    def copy(self):
621
 
        """ Return a copy of the current filespace.
622
 
        Only the files are copied, not the modified or open file lists.
623
 
        """
624
 
        self.flush_all()
625
 
        return TestFilespace(self._files)
626
 
 
627
 
class TestStringIO(StringIO.StringIO):
628
 
    """
629
 
    A subclass of StringIO which acts as a file in our dummy file system
630
 
    """
631
 
    def __init__(self, string, filename, filespace, reading_ok, writing_ok, existing_data):
632
 
        """ Initialise with the filedata, file name and infomation on what ops are
633
 
        acceptable """
634
 
        StringIO.StringIO.__init__(self, string)
635
 
        self._filename = filename
636
 
        self._filespace = filespace
637
 
        self._reading_ok = reading_ok
638
 
        self._writing_ok = writing_ok
639
 
        self._existing_data = existing_data
640
 
        self._modified = False
641
 
        self._open = True
642
 
 
643
 
    # Override all standard file ops. Make sure that they are valid with the given
644
 
    # permissions and if so then call the corresponding method in StringIO
645
 
    
646
 
    def read(self, *args):
647
 
        if not self._reading_ok:
648
 
            raise IOError(9, "Bad file descriptor")
649
 
        else:
650
 
            return StringIO.StringIO.read(self, *args)
651
 
 
652
 
    def readline(self, *args):
653
 
        if not self._reading_ok:
654
 
            raise IOError(9, "Bad file descriptor")
655
 
        else:
656
 
            return StringIO.StringIO.readline(self, *args)
657
 
 
658
 
    def readlines(self, *args):
659
 
        if not self._reading_ok:
660
 
            raise IOError(9, "Bad file descriptor")
661
 
        else:
662
 
            return StringIO.StringIO.readlines(self, *args)
663
 
 
664
 
    def seek(self, *args):
665
 
        if not self._reading_ok:
666
 
            raise IOError(9, "Bad file descriptor")
667
 
        else:
668
 
            return StringIO.StringIO.seek(self, *args)
669
 
 
670
 
    def truncate(self, *args):
671
 
        self._modified = True
672
 
        if not self._writing_ok:
673
 
            raise IOError(9, "Bad file descriptor")
674
 
        else:
675
 
            return StringIO.StringIO.truncate(self, *args)
676
 
        
677
 
    def write(self, *args):
678
 
        self._modified = True
679
 
        if not self._writing_ok:
680
 
            raise IOError(9, "Bad file descriptor")
681
 
        else:
682
 
            return StringIO.StringIO.write(self, *args)
683
 
 
684
 
    def writelines(self, *args):
685
 
        self._modified = True
686
 
        if not self._writing_ok:
687
 
            raise IOError(9, "Bad file descriptor")
688
 
        else:
689
 
            return StringIO.StringIO.writelines(self, *args)
690
 
 
691
 
    def is_modified(self):
692
 
        " Return true if the file has been written to, or truncated"
693
 
        return self._modified
694
 
        
695
 
    def flush(self):
696
 
        " Update the contents of the filespace with the new data "
697
 
        self._filespace.updatefile(self._filename, self._existing_data+self.getvalue())
698
 
        return StringIO.StringIO.flush(self)
699
 
 
700
 
    def close(self):
701
 
        " Flush the file and close it "
702
 
        self.flush()
703
 
        self._filespace.closefile(self._filename)
704
 
        return StringIO.StringIO.close(self)
705
 
 
706
 
##def get_function(filename, function_name):
707
 
##      import compiler
708
 
##      mod = compiler.parseFile(filename)
709
 
##      for node in mod.node.nodes:
710
 
##              if isinstance(node, compiler.ast.Function) and node.name == function_name:
711
 
##                      return node
712
 
##