diff --git a/CMakeLists.txt b/CMakeLists.txt index a52cf47ec..360907867 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -227,4 +227,6 @@ add_custom_target(memcheck valgrind --leak-check=full --show-reachable=yes ${PRO DEPENDS storm) add_custom_target(memcheck-tests valgrind --leak-check=full --show-reachable=yes ${PROJECT_BINARY_DIR}/storm-tests DEPENDS storm-tests) - + +set (CPPLINT_ARGS --filter=-whitespace/tab,-whitespace/line_length,-legal/copyright,-readability/streams) +add_custom_target(style python cpplint.py ${CPPLINT_ARGS} `find ./src/ -iname "*.h" -or -iname "*.cpp"`) diff --git a/cpplint.py b/cpplint.py new file mode 100644 index 000000000..526b9556d --- /dev/null +++ b/cpplint.py @@ -0,0 +1,3361 @@ +#!/usr/bin/python +# +# Copyright (c) 2009 Google Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Here are some issues that I've had people identify in my code during reviews, +# that I think are possible to flag automatically in a lint tool. If these were +# caught by lint, it would save time both for myself and that of my reviewers. +# Most likely, some of these are beyond the scope of the current lint framework, +# but I think it is valuable to retain these wish-list items even if they cannot +# be immediately implemented. +# +# Suggestions +# ----------- +# - Check for no 'explicit' for multi-arg ctor +# - Check for boolean assign RHS in parens +# - Check for ctor initializer-list colon position and spacing +# - Check that if there's a ctor, there should be a dtor +# - Check accessors that return non-pointer member variables are +# declared const +# - Check accessors that return non-const pointer member vars are +# *not* declared const +# - Check for using public includes for testing +# - Check for spaces between brackets in one-line inline method +# - Check for no assert() +# - Check for spaces surrounding operators +# - Check for 0 in pointer context (should be NULL) +# - Check for 0 in char context (should be '\0') +# - Check for camel-case method name conventions for methods +# that are not simple inline getters and setters +# - Check that base classes have virtual destructors +# put " // namespace" after } that closes a namespace, with +# namespace's name after 'namespace' if it is named. +# - Do not indent namespace contents +# - Avoid inlining non-trivial constructors in header files +# include base/basictypes.h if DISALLOW_EVIL_CONSTRUCTORS is used +# - Check for old-school (void) cast for call-sites of functions +# ignored return value +# - Check gUnit usage of anonymous namespace +# - Check for class declaration order (typedefs, consts, enums, +# ctor(s?), dtor, friend declarations, methods, member vars) +# + +"""Does google-lint on c++ files. + +The goal of this script is to identify places in the code that *may* +be in non-compliance with google style. It does not attempt to fix +up these problems -- the point is to educate. It does also not +attempt to find all problems, or to ensure that everything it does +find is legitimately a problem. + +In particular, we can get very confused by /* and // inside strings! +We do a small hack, which is to ignore //'s with "'s after them on the +same line, but it is far from perfect (in either direction). +""" + +import codecs +import getopt +import math # for log +import os +import re +import sre_compile +import string +import sys +import unicodedata + + +_USAGE = """ +Syntax: cpplint.py [--verbose=#] [--output=vs7] [--filter=-x,+y,...] + [--counting=total|toplevel|detailed] + [file] ... + + The style guidelines this tries to follow are those in + http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml + + Every problem is given a confidence score from 1-5, with 5 meaning we are + certain of the problem, and 1 meaning it could be a legitimate construct. + This will miss some errors, and is not a substitute for a code review. + + To suppress false-positive errors of a certain category, add a + 'NOLINT(category)' comment to the line. NOLINT or NOLINT(*) + suppresses errors of all categories on that line. + + The files passed in will be linted; at least one file must be provided. + Linted extensions are .cc, .cpp, and .h. Other file types will be ignored. + + Flags: + + output=vs7 + By default, the output is formatted to ease emacs parsing. Visual Studio + compatible output (vs7) may also be used. Other formats are unsupported. + + verbose=# + Specify a number 0-5 to restrict errors to certain verbosity levels. + + filter=-x,+y,... + Specify a comma-separated list of category-filters to apply: only + error messages whose category names pass the filters will be printed. + (Category names are printed with the message and look like + "[whitespace/indent]".) Filters are evaluated left to right. + "-FOO" and "FOO" means "do not print categories that start with FOO". + "+FOO" means "do print categories that start with FOO". + + Examples: --filter=-whitespace,+whitespace/braces + --filter=whitespace,runtime/printf,+runtime/printf_format + --filter=-,+build/include_what_you_use + + To see a list of all the categories used in cpplint, pass no arg: + --filter= + + counting=total|toplevel|detailed + The total number of errors found is always printed. If + 'toplevel' is provided, then the count of errors in each of + the top-level categories like 'build' and 'whitespace' will + also be printed. If 'detailed' is provided, then a count + is provided for each category like 'build/class'. +""" + +# We categorize each error message we print. Here are the categories. +# We want an explicit list so we can list them all in cpplint --filter=. +# If you add a new error message with a new category, add it to the list +# here! cpplint_unittest.py should tell you if you forget to do this. +# \ used for clearer layout -- pylint: disable-msg=C6013 +_ERROR_CATEGORIES = [ + 'build/class', + 'build/deprecated', + 'build/endif_comment', + 'build/explicit_make_pair', + 'build/forward_decl', + 'build/header_guard', + 'build/include', + 'build/include_alpha', + 'build/include_order', + 'build/include_what_you_use', + 'build/namespaces', + 'build/printf_format', + 'build/storage_class', + 'legal/copyright', + 'readability/braces', + 'readability/casting', + 'readability/check', + 'readability/constructors', + 'readability/fn_size', + 'readability/function', + 'readability/multiline_comment', + 'readability/multiline_string', + 'readability/nolint', + 'readability/streams', + 'readability/todo', + 'readability/utf8', + 'runtime/arrays', + 'runtime/casting', + 'runtime/explicit', + 'runtime/int', + 'runtime/init', + 'runtime/invalid_increment', + 'runtime/member_string_references', + 'runtime/memset', + 'runtime/operator', + 'runtime/printf', + 'runtime/printf_format', + 'runtime/references', + 'runtime/rtti', + 'runtime/sizeof', + 'runtime/string', + 'runtime/threadsafe_fn', + 'runtime/virtual', + 'whitespace/blank_line', + 'whitespace/braces', + 'whitespace/comma', + 'whitespace/comments', + 'whitespace/end_of_line', + 'whitespace/ending_newline', + 'whitespace/indent', + 'whitespace/labels', + 'whitespace/line_length', + 'whitespace/newline', + 'whitespace/operators', + 'whitespace/parens', + 'whitespace/semicolon', + 'whitespace/tab', + 'whitespace/todo' + ] + +# The default state of the category filter. This is overrided by the --filter= +# flag. By default all errors are on, so only add here categories that should be +# off by default (i.e., categories that must be enabled by the --filter= flags). +# All entries here should start with a '-' or '+', as in the --filter= flag. +_DEFAULT_FILTERS = ['-build/include_alpha'] + +# We used to check for high-bit characters, but after much discussion we +# decided those were OK, as long as they were in UTF-8 and didn't represent +# hard-coded international strings, which belong in a separate i18n file. + +# Headers that we consider STL headers. +_STL_HEADERS = frozenset([ + 'algobase.h', 'algorithm', 'alloc.h', 'bitset', 'deque', 'exception', + 'function.h', 'functional', 'hash_map', 'hash_map.h', 'hash_set', + 'hash_set.h', 'iterator', 'list', 'list.h', 'map', 'memory', 'new', + 'pair.h', 'pthread_alloc', 'queue', 'set', 'set.h', 'sstream', 'stack', + 'stl_alloc.h', 'stl_relops.h', 'type_traits.h', + 'utility', 'vector', 'vector.h', + ]) + + +# Non-STL C++ system headers. +_CPP_HEADERS = frozenset([ + 'algo.h', 'builtinbuf.h', 'bvector.h', 'cassert', 'cctype', + 'cerrno', 'cfloat', 'ciso646', 'climits', 'clocale', 'cmath', + 'complex', 'complex.h', 'csetjmp', 'csignal', 'cstdarg', 'cstddef', + 'cstdio', 'cstdlib', 'cstring', 'ctime', 'cwchar', 'cwctype', + 'defalloc.h', 'deque.h', 'editbuf.h', 'exception', 'fstream', + 'fstream.h', 'hashtable.h', 'heap.h', 'indstream.h', 'iomanip', + 'iomanip.h', 'ios', 'iosfwd', 'iostream', 'iostream.h', 'istream', + 'istream.h', 'iterator.h', 'limits', 'map.h', 'multimap.h', 'multiset.h', + 'numeric', 'ostream', 'ostream.h', 'parsestream.h', 'pfstream.h', + 'PlotFile.h', 'procbuf.h', 'pthread_alloc.h', 'rope', 'rope.h', + 'ropeimpl.h', 'SFile.h', 'slist', 'slist.h', 'stack.h', 'stdexcept', + 'stdiostream.h', 'streambuf.h', 'stream.h', 'strfile.h', 'string', + 'strstream', 'strstream.h', 'tempbuf.h', 'tree.h', 'typeinfo', 'valarray', + ]) + + +# Assertion macros. These are defined in base/logging.h and +# testing/base/gunit.h. Note that the _M versions need to come first +# for substring matching to work. +_CHECK_MACROS = [ + 'DCHECK', 'CHECK', + 'EXPECT_TRUE_M', 'EXPECT_TRUE', + 'ASSERT_TRUE_M', 'ASSERT_TRUE', + 'EXPECT_FALSE_M', 'EXPECT_FALSE', + 'ASSERT_FALSE_M', 'ASSERT_FALSE', + ] + +# Replacement macros for CHECK/DCHECK/EXPECT_TRUE/EXPECT_FALSE +_CHECK_REPLACEMENT = dict([(m, {}) for m in _CHECK_MACROS]) + +for op, replacement in [('==', 'EQ'), ('!=', 'NE'), + ('>=', 'GE'), ('>', 'GT'), + ('<=', 'LE'), ('<', 'LT')]: + _CHECK_REPLACEMENT['DCHECK'][op] = 'DCHECK_%s' % replacement + _CHECK_REPLACEMENT['CHECK'][op] = 'CHECK_%s' % replacement + _CHECK_REPLACEMENT['EXPECT_TRUE'][op] = 'EXPECT_%s' % replacement + _CHECK_REPLACEMENT['ASSERT_TRUE'][op] = 'ASSERT_%s' % replacement + _CHECK_REPLACEMENT['EXPECT_TRUE_M'][op] = 'EXPECT_%s_M' % replacement + _CHECK_REPLACEMENT['ASSERT_TRUE_M'][op] = 'ASSERT_%s_M' % replacement + +for op, inv_replacement in [('==', 'NE'), ('!=', 'EQ'), + ('>=', 'LT'), ('>', 'LE'), + ('<=', 'GT'), ('<', 'GE')]: + _CHECK_REPLACEMENT['EXPECT_FALSE'][op] = 'EXPECT_%s' % inv_replacement + _CHECK_REPLACEMENT['ASSERT_FALSE'][op] = 'ASSERT_%s' % inv_replacement + _CHECK_REPLACEMENT['EXPECT_FALSE_M'][op] = 'EXPECT_%s_M' % inv_replacement + _CHECK_REPLACEMENT['ASSERT_FALSE_M'][op] = 'ASSERT_%s_M' % inv_replacement + + +# These constants define types of headers for use with +# _IncludeState.CheckNextIncludeOrder(). +_C_SYS_HEADER = 1 +_CPP_SYS_HEADER = 2 +_LIKELY_MY_HEADER = 3 +_POSSIBLE_MY_HEADER = 4 +_OTHER_HEADER = 5 + + +_regexp_compile_cache = {} + +# Finds occurrences of NOLINT or NOLINT(...). +_RE_SUPPRESSION = re.compile(r'\bNOLINT\b(\([^)]*\))?') + +# {str, set(int)}: a map from error categories to sets of linenumbers +# on which those errors are expected and should be suppressed. +_error_suppressions = {} + +def ParseNolintSuppressions(filename, raw_line, linenum, error): + """Updates the global list of error-suppressions. + + Parses any NOLINT comments on the current line, updating the global + error_suppressions store. Reports an error if the NOLINT comment + was malformed. + + Args: + filename: str, the name of the input file. + raw_line: str, the line of input text, with comments. + linenum: int, the number of the current line. + error: function, an error handler. + """ + # FIXME(adonovan): "NOLINT(" is misparsed as NOLINT(*). + matched = _RE_SUPPRESSION.search(raw_line) + if matched: + category = matched.group(1) + if category in (None, '(*)'): # => "suppress all" + _error_suppressions.setdefault(None, set()).add(linenum) + else: + if category.startswith('(') and category.endswith(')'): + category = category[1:-1] + if category in _ERROR_CATEGORIES: + _error_suppressions.setdefault(category, set()).add(linenum) + else: + error(filename, linenum, 'readability/nolint', 5, + 'Unknown NOLINT error category: %s' % category) + + +def ResetNolintSuppressions(): + "Resets the set of NOLINT suppressions to empty." + _error_suppressions.clear() + + +def IsErrorSuppressedByNolint(category, linenum): + """Returns true if the specified error category is suppressed on this line. + + Consults the global error_suppressions map populated by + ParseNolintSuppressions/ResetNolintSuppressions. + + Args: + category: str, the category of the error. + linenum: int, the current line number. + Returns: + bool, True iff the error should be suppressed due to a NOLINT comment. + """ + return (linenum in _error_suppressions.get(category, set()) or + linenum in _error_suppressions.get(None, set())) + +def Match(pattern, s): + """Matches the string with the pattern, caching the compiled regexp.""" + # The regexp compilation caching is inlined in both Match and Search for + # performance reasons; factoring it out into a separate function turns out + # to be noticeably expensive. + if not pattern in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].match(s) + + +def Search(pattern, s): + """Searches the string for the pattern, caching the compiled regexp.""" + if not pattern in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].search(s) + + +class _IncludeState(dict): + """Tracks line numbers for includes, and the order in which includes appear. + + As a dict, an _IncludeState object serves as a mapping between include + filename and line number on which that file was included. + + Call CheckNextIncludeOrder() once for each header in the file, passing + in the type constants defined above. Calls in an illegal order will + raise an _IncludeError with an appropriate error message. + + """ + # self._section will move monotonically through this set. If it ever + # needs to move backwards, CheckNextIncludeOrder will raise an error. + _INITIAL_SECTION = 0 + _MY_H_SECTION = 1 + _C_SECTION = 2 + _CPP_SECTION = 3 + _OTHER_H_SECTION = 4 + + _TYPE_NAMES = { + _C_SYS_HEADER: 'C system header', + _CPP_SYS_HEADER: 'C++ system header', + _LIKELY_MY_HEADER: 'header this file implements', + _POSSIBLE_MY_HEADER: 'header this file may implement', + _OTHER_HEADER: 'other header', + } + _SECTION_NAMES = { + _INITIAL_SECTION: "... nothing. (This can't be an error.)", + _MY_H_SECTION: 'a header this file implements', + _C_SECTION: 'C system header', + _CPP_SECTION: 'C++ system header', + _OTHER_H_SECTION: 'other header', + } + + def __init__(self): + dict.__init__(self) + # The name of the current section. + self._section = self._INITIAL_SECTION + # The path of last found header. + self._last_header = '' + + def CanonicalizeAlphabeticalOrder(self, header_path): + """Returns a path canonicalized for alphabetical comparison. + + - replaces "-" with "_" so they both cmp the same. + - removes '-inl' since we don't require them to be after the main header. + - lowercase everything, just in case. + + Args: + header_path: Path to be canonicalized. + + Returns: + Canonicalized path. + """ + return header_path.replace('-inl.h', '.h').replace('-', '_').lower() + + def IsInAlphabeticalOrder(self, header_path): + """Check if a header is in alphabetical order with the previous header. + + Args: + header_path: Header to be checked. + + Returns: + Returns true if the header is in alphabetical order. + """ + canonical_header = self.CanonicalizeAlphabeticalOrder(header_path) + if self._last_header > canonical_header: + return False + self._last_header = canonical_header + return True + + def CheckNextIncludeOrder(self, header_type): + """Returns a non-empty error message if the next header is out of order. + + This function also updates the internal state to be ready to check + the next include. + + Args: + header_type: One of the _XXX_HEADER constants defined above. + + Returns: + The empty string if the header is in the right order, or an + error message describing what's wrong. + + """ + error_message = ('Found %s after %s' % + (self._TYPE_NAMES[header_type], + self._SECTION_NAMES[self._section])) + + last_section = self._section + + if header_type == _C_SYS_HEADER: + if self._section <= self._C_SECTION: + self._section = self._C_SECTION + else: + self._last_header = '' + return error_message + elif header_type == _CPP_SYS_HEADER: + if self._section <= self._CPP_SECTION: + self._section = self._CPP_SECTION + else: + self._last_header = '' + return error_message + elif header_type == _LIKELY_MY_HEADER: + if self._section <= self._MY_H_SECTION: + self._section = self._MY_H_SECTION + else: + self._section = self._OTHER_H_SECTION + elif header_type == _POSSIBLE_MY_HEADER: + if self._section <= self._MY_H_SECTION: + self._section = self._MY_H_SECTION + else: + # This will always be the fallback because we're not sure + # enough that the header is associated with this file. + self._section = self._OTHER_H_SECTION + else: + assert header_type == _OTHER_HEADER + self._section = self._OTHER_H_SECTION + + if last_section != self._section: + self._last_header = '' + + return '' + + +class _CppLintState(object): + """Maintains module-wide state..""" + + def __init__(self): + self.verbose_level = 1 # global setting. + self.error_count = 0 # global count of reported errors + # filters to apply when emitting error messages + self.filters = _DEFAULT_FILTERS[:] + self.counting = 'total' # In what way are we counting errors? + self.errors_by_category = {} # string to int dict storing error counts + + # output format: + # "emacs" - format that emacs can parse (default) + # "vs7" - format that Microsoft Visual Studio 7 can parse + self.output_format = 'emacs' + + def SetOutputFormat(self, output_format): + """Sets the output format for errors.""" + self.output_format = output_format + + def SetVerboseLevel(self, level): + """Sets the module's verbosity, and returns the previous setting.""" + last_verbose_level = self.verbose_level + self.verbose_level = level + return last_verbose_level + + def SetCountingStyle(self, counting_style): + """Sets the module's counting options.""" + self.counting = counting_style + + def SetFilters(self, filters): + """Sets the error-message filters. + + These filters are applied when deciding whether to emit a given + error message. + + Args: + filters: A string of comma-separated filters (eg "+whitespace/indent"). + Each filter should start with + or -; else we die. + + Raises: + ValueError: The comma-separated filters did not all start with '+' or '-'. + E.g. "-,+whitespace,-whitespace/indent,whitespace/badfilter" + """ + # Default filters always have less priority than the flag ones. + self.filters = _DEFAULT_FILTERS[:] + for filt in filters.split(','): + clean_filt = filt.strip() + if clean_filt: + self.filters.append(clean_filt) + for filt in self.filters: + if not (filt.startswith('+') or filt.startswith('-')): + raise ValueError('Every filter in --filters must start with + or -' + ' (%s does not)' % filt) + + def ResetErrorCounts(self): + """Sets the module's error statistic back to zero.""" + self.error_count = 0 + self.errors_by_category = {} + + def IncrementErrorCount(self, category): + """Bumps the module's error statistic.""" + self.error_count += 1 + if self.counting in ('toplevel', 'detailed'): + if self.counting != 'detailed': + category = category.split('/')[0] + if category not in self.errors_by_category: + self.errors_by_category[category] = 0 + self.errors_by_category[category] += 1 + + def PrintErrorCounts(self): + """Print a summary of errors by category, and the total.""" + for category, count in self.errors_by_category.iteritems(): + sys.stderr.write('Category \'%s\' errors found: %d\n' % + (category, count)) + sys.stderr.write('Total errors found: %d\n' % self.error_count) + +_cpplint_state = _CppLintState() + + +def _OutputFormat(): + """Gets the module's output format.""" + return _cpplint_state.output_format + + +def _SetOutputFormat(output_format): + """Sets the module's output format.""" + _cpplint_state.SetOutputFormat(output_format) + + +def _VerboseLevel(): + """Returns the module's verbosity setting.""" + return _cpplint_state.verbose_level + + +def _SetVerboseLevel(level): + """Sets the module's verbosity, and returns the previous setting.""" + return _cpplint_state.SetVerboseLevel(level) + + +def _SetCountingStyle(level): + """Sets the module's counting options.""" + _cpplint_state.SetCountingStyle(level) + + +def _Filters(): + """Returns the module's list of output filters, as a list.""" + return _cpplint_state.filters + + +def _SetFilters(filters): + """Sets the module's error-message filters. + + These filters are applied when deciding whether to emit a given + error message. + + Args: + filters: A string of comma-separated filters (eg "whitespace/indent"). + Each filter should start with + or -; else we die. + """ + _cpplint_state.SetFilters(filters) + + +class _FunctionState(object): + """Tracks current function name and the number of lines in its body.""" + + _NORMAL_TRIGGER = 250 # for --v=0, 500 for --v=1, etc. + _TEST_TRIGGER = 400 # about 50% more than _NORMAL_TRIGGER. + + def __init__(self): + self.in_a_function = False + self.lines_in_function = 0 + self.current_function = '' + + def Begin(self, function_name): + """Start analyzing function body. + + Args: + function_name: The name of the function being tracked. + """ + self.in_a_function = True + self.lines_in_function = 0 + self.current_function = function_name + + def Count(self): + """Count line in current function body.""" + if self.in_a_function: + self.lines_in_function += 1 + + def Check(self, error, filename, linenum): + """Report if too many lines in function body. + + Args: + error: The function to call with any errors found. + filename: The name of the current file. + linenum: The number of the line to check. + """ + if Match(r'T(EST|est)', self.current_function): + base_trigger = self._TEST_TRIGGER + else: + base_trigger = self._NORMAL_TRIGGER + trigger = base_trigger * 2**_VerboseLevel() + + if self.lines_in_function > trigger: + error_level = int(math.log(self.lines_in_function / base_trigger, 2)) + # 50 => 0, 100 => 1, 200 => 2, 400 => 3, 800 => 4, 1600 => 5, ... + if error_level > 5: + error_level = 5 + error(filename, linenum, 'readability/fn_size', error_level, + 'Small and focused functions are preferred:' + ' %s has %d non-comment lines' + ' (error triggered by exceeding %d lines).' % ( + self.current_function, self.lines_in_function, trigger)) + + def End(self): + """Stop analyzing function body.""" + self.in_a_function = False + + +class _IncludeError(Exception): + """Indicates a problem with the include order in a file.""" + pass + + +class FileInfo: + """Provides utility functions for filenames. + + FileInfo provides easy access to the components of a file's path + relative to the project root. + """ + + def __init__(self, filename): + self._filename = filename + + def FullName(self): + """Make Windows paths like Unix.""" + return os.path.abspath(self._filename).replace('\\', '/') + + def RepositoryName(self): + """FullName after removing the local path to the repository. + + If we have a real absolute path name here we can try to do something smart: + detecting the root of the checkout and truncating /path/to/checkout from + the name so that we get header guards that don't include things like + "C:\Documents and Settings\..." or "/home/username/..." in them and thus + people on different computers who have checked the source out to different + locations won't see bogus errors. + """ + fullname = self.FullName() + + if os.path.exists(fullname): + project_dir = os.path.dirname(fullname) + + if os.path.exists(os.path.join(project_dir, ".svn")): + # If there's a .svn file in the current directory, we recursively look + # up the directory tree for the top of the SVN checkout + root_dir = project_dir + one_up_dir = os.path.dirname(root_dir) + while os.path.exists(os.path.join(one_up_dir, ".svn")): + root_dir = os.path.dirname(root_dir) + one_up_dir = os.path.dirname(one_up_dir) + + prefix = os.path.commonprefix([root_dir, project_dir]) + return fullname[len(prefix) + 1:] + + # Not SVN <= 1.6? Try to find a git, hg, or svn top level directory by + # searching up from the current path. + root_dir = os.path.dirname(fullname) + while (root_dir != os.path.dirname(root_dir) and + not os.path.exists(os.path.join(root_dir, ".git")) and + not os.path.exists(os.path.join(root_dir, ".hg")) and + not os.path.exists(os.path.join(root_dir, ".svn"))): + root_dir = os.path.dirname(root_dir) + + if (os.path.exists(os.path.join(root_dir, ".git")) or + os.path.exists(os.path.join(root_dir, ".hg")) or + os.path.exists(os.path.join(root_dir, ".svn"))): + prefix = os.path.commonprefix([root_dir, project_dir]) + return fullname[len(prefix) + 1:] + + # Don't know what to do; header guard warnings may be wrong... + return fullname + + def Split(self): + """Splits the file into the directory, basename, and extension. + + For 'chrome/browser/browser.cc', Split() would + return ('chrome/browser', 'browser', '.cc') + + Returns: + A tuple of (directory, basename, extension). + """ + + googlename = self.RepositoryName() + project, rest = os.path.split(googlename) + return (project,) + os.path.splitext(rest) + + def BaseName(self): + """File base name - text after the final slash, before the final period.""" + return self.Split()[1] + + def Extension(self): + """File extension - text following the final period.""" + return self.Split()[2] + + def NoExtension(self): + """File has no source file extension.""" + return '/'.join(self.Split()[0:2]) + + def IsSource(self): + """File has a source file extension.""" + return self.Extension()[1:] in ('c', 'cc', 'cpp', 'cxx') + + +def _ShouldPrintError(category, confidence, linenum): + """If confidence >= verbose, category passes filter and is not suppressed.""" + + # There are three ways we might decide not to print an error message: + # a "NOLINT(category)" comment appears in the source, + # the verbosity level isn't high enough, or the filters filter it out. + if IsErrorSuppressedByNolint(category, linenum): + return False + if confidence < _cpplint_state.verbose_level: + return False + + is_filtered = False + for one_filter in _Filters(): + if one_filter.startswith('-'): + if category.startswith(one_filter[1:]): + is_filtered = True + elif one_filter.startswith('+'): + if category.startswith(one_filter[1:]): + is_filtered = False + else: + assert False # should have been checked for in SetFilter. + if is_filtered: + return False + + return True + + +def Error(filename, linenum, category, confidence, message): + """Logs the fact we've found a lint error. + + We log where the error was found, and also our confidence in the error, + that is, how certain we are this is a legitimate style regression, and + not a misidentification or a use that's sometimes justified. + + False positives can be suppressed by the use of + "cpplint(category)" comments on the offending line. These are + parsed into _error_suppressions. + + Args: + filename: The name of the file containing the error. + linenum: The number of the line containing the error. + category: A string used to describe the "category" this bug + falls under: "whitespace", say, or "runtime". Categories + may have a hierarchy separated by slashes: "whitespace/indent". + confidence: A number from 1-5 representing a confidence score for + the error, with 5 meaning that we are certain of the problem, + and 1 meaning that it could be a legitimate construct. + message: The error message. + """ + if _ShouldPrintError(category, confidence, linenum): + _cpplint_state.IncrementErrorCount(category) + if _cpplint_state.output_format == 'vs7': + sys.stderr.write('%s(%s): %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + else: + sys.stderr.write('%s:%s: %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + + +# Matches standard C++ escape esequences per 2.13.2.3 of the C++ standard. +_RE_PATTERN_CLEANSE_LINE_ESCAPES = re.compile( + r'\\([abfnrtv?"\\\']|\d+|x[0-9a-fA-F]+)') +# Matches strings. Escape codes should already be removed by ESCAPES. +_RE_PATTERN_CLEANSE_LINE_DOUBLE_QUOTES = re.compile(r'"[^"]*"') +# Matches characters. Escape codes should already be removed by ESCAPES. +_RE_PATTERN_CLEANSE_LINE_SINGLE_QUOTES = re.compile(r"'.'") +# Matches multi-line C++ comments. +# This RE is a little bit more complicated than one might expect, because we +# have to take care of space removals tools so we can handle comments inside +# statements better. +# The current rule is: We only clear spaces from both sides when we're at the +# end of the line. Otherwise, we try to remove spaces from the right side, +# if this doesn't work we try on left side but only if there's a non-character +# on the right. +_RE_PATTERN_CLEANSE_LINE_C_COMMENTS = re.compile( + r"""(\s*/\*.*\*/\s*$| + /\*.*\*/\s+| + \s+/\*.*\*/(?=\W)| + /\*.*\*/)""", re.VERBOSE) + + +def IsCppString(line): + """Does line terminate so, that the next symbol is in string constant. + + This function does not consider single-line nor multi-line comments. + + Args: + line: is a partial line of code starting from the 0..n. + + Returns: + True, if next character appended to 'line' is inside a + string constant. + """ + + line = line.replace(r'\\', 'XX') # after this, \\" does not match to \" + return ((line.count('"') - line.count(r'\"') - line.count("'\"'")) & 1) == 1 + + +def FindNextMultiLineCommentStart(lines, lineix): + """Find the beginning marker for a multiline comment.""" + while lineix < len(lines): + if lines[lineix].strip().startswith('/*'): + # Only return this marker if the comment goes beyond this line + if lines[lineix].strip().find('*/', 2) < 0: + return lineix + lineix += 1 + return len(lines) + + +def FindNextMultiLineCommentEnd(lines, lineix): + """We are inside a comment, find the end marker.""" + while lineix < len(lines): + if lines[lineix].strip().endswith('*/'): + return lineix + lineix += 1 + return len(lines) + + +def RemoveMultiLineCommentsFromRange(lines, begin, end): + """Clears a range of lines for multi-line comments.""" + # Having // dummy comments makes the lines non-empty, so we will not get + # unnecessary blank line warnings later in the code. + for i in range(begin, end): + lines[i] = '// dummy' + + +def RemoveMultiLineComments(filename, lines, error): + """Removes multiline (c-style) comments from lines.""" + lineix = 0 + while lineix < len(lines): + lineix_begin = FindNextMultiLineCommentStart(lines, lineix) + if lineix_begin >= len(lines): + return + lineix_end = FindNextMultiLineCommentEnd(lines, lineix_begin) + if lineix_end >= len(lines): + error(filename, lineix_begin + 1, 'readability/multiline_comment', 5, + 'Could not find end of multi-line comment') + return + RemoveMultiLineCommentsFromRange(lines, lineix_begin, lineix_end + 1) + lineix = lineix_end + 1 + + +def CleanseComments(line): + """Removes //-comments and single-line C-style /* */ comments. + + Args: + line: A line of C++ source. + + Returns: + The line with single-line comments removed. + """ + commentpos = line.find('//') + if commentpos != -1 and not IsCppString(line[:commentpos]): + line = line[:commentpos].rstrip() + # get rid of /* ... */ + return _RE_PATTERN_CLEANSE_LINE_C_COMMENTS.sub('', line) + + +class CleansedLines(object): + """Holds 3 copies of all lines with different preprocessing applied to them. + + 1) elided member contains lines without strings and comments, + 2) lines member contains lines without comments, and + 3) raw member contains all the lines without processing. + All these three members are of , and of the same length. + """ + + def __init__(self, lines): + self.elided = [] + self.lines = [] + self.raw_lines = lines + self.num_lines = len(lines) + for linenum in range(len(lines)): + self.lines.append(CleanseComments(lines[linenum])) + elided = self._CollapseStrings(lines[linenum]) + self.elided.append(CleanseComments(elided)) + + def NumLines(self): + """Returns the number of lines represented.""" + return self.num_lines + + @staticmethod + def _CollapseStrings(elided): + """Collapses strings and chars on a line to simple "" or '' blocks. + + We nix strings first so we're not fooled by text like '"http://"' + + Args: + elided: The line being processed. + + Returns: + The line with collapsed strings. + """ + if not _RE_PATTERN_INCLUDE.match(elided): + # Remove escaped characters first to make quote/single quote collapsing + # basic. Things that look like escaped characters shouldn't occur + # outside of strings and chars. + elided = _RE_PATTERN_CLEANSE_LINE_ESCAPES.sub('', elided) + elided = _RE_PATTERN_CLEANSE_LINE_SINGLE_QUOTES.sub("''", elided) + elided = _RE_PATTERN_CLEANSE_LINE_DOUBLE_QUOTES.sub('""', elided) + return elided + + +def CloseExpression(clean_lines, linenum, pos): + """If input points to ( or { or [, finds the position that closes it. + + If lines[linenum][pos] points to a '(' or '{' or '[', finds the + linenum/pos that correspond to the closing of the expression. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + pos: A position on the line. + + Returns: + A tuple (line, linenum, pos) pointer *past* the closing brace, or + (line, len(lines), -1) if we never find a close. Note we ignore + strings and comments when matching; and the line we return is the + 'cleansed' line at linenum. + """ + + line = clean_lines.elided[linenum] + startchar = line[pos] + if startchar not in '({[': + return (line, clean_lines.NumLines(), -1) + if startchar == '(': endchar = ')' + if startchar == '[': endchar = ']' + if startchar == '{': endchar = '}' + + num_open = line.count(startchar) - line.count(endchar) + while linenum < clean_lines.NumLines() and num_open > 0: + linenum += 1 + line = clean_lines.elided[linenum] + num_open += line.count(startchar) - line.count(endchar) + # OK, now find the endchar that actually got us back to even + endpos = len(line) + while num_open >= 0: + endpos = line.rfind(')', 0, endpos) + num_open -= 1 # chopped off another ) + return (line, linenum, endpos + 1) + + +def CheckForCopyright(filename, lines, error): + """Logs an error if no Copyright message appears at the top of the file.""" + + # We'll say it should occur by line 10. Don't forget there's a + # dummy line at the front. + for line in xrange(1, min(len(lines), 11)): + if re.search(r'Copyright', lines[line], re.I): break + else: # means no copyright line was found + error(filename, 0, 'legal/copyright', 5, + 'No copyright message found. ' + 'You should have a line: "Copyright [year] "') + + +def GetHeaderGuardCPPVariable(filename): + """Returns the CPP variable that should be used as a header guard. + + Args: + filename: The name of a C++ header file. + + Returns: + The CPP variable that should be used as a header guard in the + named file. + + """ + + # Restores original filename in case that cpplint is invoked from Emacs's + # flymake. + filename = re.sub(r'_flymake\.h$', '.h', filename) + + fileinfo = FileInfo(filename) + return re.sub(r'[-./\s]', '_', fileinfo.RepositoryName()).upper() + '_' + + +def CheckForHeaderGuard(filename, lines, error): + """Checks that the file contains a header guard. + + Logs an error if no #ifndef header guard is present. For other + headers, checks that the full pathname is used. + + Args: + filename: The name of the C++ header file. + lines: An array of strings, each representing a line of the file. + error: The function to call with any errors found. + """ + + cppvar = GetHeaderGuardCPPVariable(filename) + + ifndef = None + ifndef_linenum = 0 + define = None + endif = None + endif_linenum = 0 + for linenum, line in enumerate(lines): + linesplit = line.split() + if len(linesplit) >= 2: + # find the first occurrence of #ifndef and #define, save arg + if not ifndef and linesplit[0] == '#ifndef': + # set ifndef to the header guard presented on the #ifndef line. + ifndef = linesplit[1] + ifndef_linenum = linenum + if not define and linesplit[0] == '#define': + define = linesplit[1] + # find the last occurrence of #endif, save entire line + if line.startswith('#endif'): + endif = line + endif_linenum = linenum + + if not ifndef: + error(filename, 0, 'build/header_guard', 5, + 'No #ifndef header guard found, suggested CPP variable is: %s' % + cppvar) + return + + if not define: + error(filename, 0, 'build/header_guard', 5, + 'No #define header guard found, suggested CPP variable is: %s' % + cppvar) + return + + # The guard should be PATH_FILE_H_, but we also allow PATH_FILE_H__ + # for backward compatibility. + if ifndef != cppvar: + error_level = 0 + if ifndef != cppvar + '_': + error_level = 5 + + ParseNolintSuppressions(filename, lines[ifndef_linenum], ifndef_linenum, + error) + error(filename, ifndef_linenum, 'build/header_guard', error_level, + '#ifndef header guard has wrong style, please use: %s' % cppvar) + + if define != ifndef: + error(filename, 0, 'build/header_guard', 5, + '#ifndef and #define don\'t match, suggested CPP variable is: %s' % + cppvar) + return + + if endif != ('#endif // %s' % cppvar): + error_level = 0 + if endif != ('#endif // %s' % (cppvar + '_')): + error_level = 5 + + ParseNolintSuppressions(filename, lines[endif_linenum], endif_linenum, + error) + error(filename, endif_linenum, 'build/header_guard', error_level, + '#endif line should be "#endif // %s"' % cppvar) + + +def CheckForUnicodeReplacementCharacters(filename, lines, error): + """Logs an error for each line containing Unicode replacement characters. + + These indicate that either the file contained invalid UTF-8 (likely) + or Unicode replacement characters (which it shouldn't). Note that + it's possible for this to throw off line numbering if the invalid + UTF-8 occurred adjacent to a newline. + + Args: + filename: The name of the current file. + lines: An array of strings, each representing a line of the file. + error: The function to call with any errors found. + """ + for linenum, line in enumerate(lines): + if u'\ufffd' in line: + error(filename, linenum, 'readability/utf8', 5, + 'Line contains invalid UTF-8 (or Unicode replacement character).') + + +def CheckForNewlineAtEOF(filename, lines, error): + """Logs an error if there is no newline char at the end of the file. + + Args: + filename: The name of the current file. + lines: An array of strings, each representing a line of the file. + error: The function to call with any errors found. + """ + + # The array lines() was created by adding two newlines to the + # original file (go figure), then splitting on \n. + # To verify that the file ends in \n, we just have to make sure the + # last-but-two element of lines() exists and is empty. + if len(lines) < 3 or lines[-2]: + error(filename, len(lines) - 2, 'whitespace/ending_newline', 5, + 'Could not find a newline character at the end of the file.') + + +def CheckForMultilineCommentsAndStrings(filename, clean_lines, linenum, error): + """Logs an error if we see /* ... */ or "..." that extend past one line. + + /* ... */ comments are legit inside macros, for one line. + Otherwise, we prefer // comments, so it's ok to warn about the + other. Likewise, it's ok for strings to extend across multiple + lines, as long as a line continuation character (backslash) + terminates each line. Although not currently prohibited by the C++ + style guide, it's ugly and unnecessary. We don't do well with either + in this lint program, so we warn about both. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Remove all \\ (escaped backslashes) from the line. They are OK, and the + # second (escaped) slash may trigger later \" detection erroneously. + line = line.replace('\\\\', '') + + if line.count('/*') > line.count('*/'): + error(filename, linenum, 'readability/multiline_comment', 5, + 'Complex multi-line /*...*/-style comment found. ' + 'Lint may give bogus warnings. ' + 'Consider replacing these with //-style comments, ' + 'with #if 0...#endif, ' + 'or with more clearly structured multi-line comments.') + + if (line.count('"') - line.count('\\"')) % 2: + error(filename, linenum, 'readability/multiline_string', 5, + 'Multi-line string ("...") found. This lint script doesn\'t ' + 'do well with such strings, and may give bogus warnings. They\'re ' + 'ugly and unnecessary, and you should use concatenation instead".') + + +threading_list = ( + ('asctime(', 'asctime_r('), + ('ctime(', 'ctime_r('), + ('getgrgid(', 'getgrgid_r('), + ('getgrnam(', 'getgrnam_r('), + ('getlogin(', 'getlogin_r('), + ('getpwnam(', 'getpwnam_r('), + ('getpwuid(', 'getpwuid_r('), + ('gmtime(', 'gmtime_r('), + ('localtime(', 'localtime_r('), + ('rand(', 'rand_r('), + ('readdir(', 'readdir_r('), + ('strtok(', 'strtok_r('), + ('ttyname(', 'ttyname_r('), + ) + + +def CheckPosixThreading(filename, clean_lines, linenum, error): + """Checks for calls to thread-unsafe functions. + + Much code has been originally written without consideration of + multi-threading. Also, engineers are relying on their old experience; + they have learned posix before threading extensions were added. These + tests guide the engineers to use thread-safe functions (when using + posix directly). + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + for single_thread_function, multithread_safe_function in threading_list: + ix = line.find(single_thread_function) + # Comparisons made explicit for clarity -- pylint: disable-msg=C6403 + if ix >= 0 and (ix == 0 or (not line[ix - 1].isalnum() and + line[ix - 1] not in ('_', '.', '>'))): + error(filename, linenum, 'runtime/threadsafe_fn', 2, + 'Consider using ' + multithread_safe_function + + '...) instead of ' + single_thread_function + + '...) for improved thread safety.') + + +# Matches invalid increment: *count++, which moves pointer instead of +# incrementing a value. +_RE_PATTERN_INVALID_INCREMENT = re.compile( + r'^\s*\*\w+(\+\+|--);') + + +def CheckInvalidIncrement(filename, clean_lines, linenum, error): + """Checks for invalid increment *count++. + + For example following function: + void increment_counter(int* count) { + *count++; + } + is invalid, because it effectively does count++, moving pointer, and should + be replaced with ++*count, (*count)++ or *count += 1. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + if _RE_PATTERN_INVALID_INCREMENT.match(line): + error(filename, linenum, 'runtime/invalid_increment', 5, + 'Changing pointer instead of value (or unused value of operator*).') + + +class _ClassInfo(object): + """Stores information about a class.""" + + def __init__(self, name, clean_lines, linenum): + self.name = name + self.linenum = linenum + self.seen_open_brace = False + self.is_derived = False + self.virtual_method_linenumber = None + self.has_virtual_destructor = False + self.brace_depth = 0 + + # Try to find the end of the class. This will be confused by things like: + # class A { + # } *x = { ... + # + # But it's still good enough for CheckSectionSpacing. + self.last_line = 0 + depth = 0 + for i in range(linenum, clean_lines.NumLines()): + line = clean_lines.lines[i] + depth += line.count('{') - line.count('}') + if not depth: + self.last_line = i + break + + +class _ClassState(object): + """Holds the current state of the parse relating to class declarations. + + It maintains a stack of _ClassInfos representing the parser's guess + as to the current nesting of class declarations. The innermost class + is at the top (back) of the stack. Typically, the stack will either + be empty or have exactly one entry. + """ + + def __init__(self): + self.classinfo_stack = [] + + def CheckFinished(self, filename, error): + """Checks that all classes have been completely parsed. + + Call this when all lines in a file have been processed. + Args: + filename: The name of the current file. + error: The function to call with any errors found. + """ + if self.classinfo_stack: + # Note: This test can result in false positives if #ifdef constructs + # get in the way of brace matching. See the testBuildClass test in + # cpplint_unittest.py for an example of this. + error(filename, self.classinfo_stack[0].linenum, 'build/class', 5, + 'Failed to find complete declaration of class %s' % + self.classinfo_stack[0].name) + + +def CheckForNonStandardConstructs(filename, clean_lines, linenum, + class_state, error): + """Logs an error if we see certain non-ANSI constructs ignored by gcc-2. + + Complain about several constructs which gcc-2 accepts, but which are + not standard C++. Warning about these in lint is one way to ease the + transition to new compilers. + - put storage class first (e.g. "static const" instead of "const static"). + - "%lld" instead of %qd" in printf-type functions. + - "%1$d" is non-standard in printf-type functions. + - "\%" is an undefined character escape sequence. + - text after #endif is not allowed. + - invalid inner-style forward declaration. + - >? and ?= and )\?=?\s*(\w+|[+-]?\d+)(\.\d*)?', + line): + error(filename, linenum, 'build/deprecated', 3, + '>? and ))?' + # r'\s*const\s*' + type_name + '\s*&\s*\w+\s*;' + error(filename, linenum, 'runtime/member_string_references', 2, + 'const string& members are dangerous. It is much better to use ' + 'alternatives, such as pointers or simple constants.') + + # Track class entry and exit, and attempt to find cases within the + # class declaration that don't meet the C++ style + # guidelines. Tracking is very dependent on the code matching Google + # style guidelines, but it seems to perform well enough in testing + # to be a worthwhile addition to the checks. + classinfo_stack = class_state.classinfo_stack + # Look for a class declaration. The regexp accounts for decorated classes + # such as in: + # class LOCKABLE API Object { + # }; + class_decl_match = Match( + r'\s*(template\s*<[\w\s<>,:]*>\s*)?' + '(class|struct)\s+([A-Z_]+\s+)*(\w+(::\w+)*)', line) + if class_decl_match: + classinfo_stack.append(_ClassInfo( + class_decl_match.group(4), clean_lines, linenum)) + + # Everything else in this function uses the top of the stack if it's + # not empty. + if not classinfo_stack: + return + + classinfo = classinfo_stack[-1] + + # If the opening brace hasn't been seen look for it and also + # parent class declarations. + if not classinfo.seen_open_brace: + # If the line has a ';' in it, assume it's a forward declaration or + # a single-line class declaration, which we won't process. + if line.find(';') != -1: + classinfo_stack.pop() + return + classinfo.seen_open_brace = (line.find('{') != -1) + # Look for a bare ':' + if Search('(^|[^:]):($|[^:])', line): + classinfo.is_derived = True + if not classinfo.seen_open_brace: + return # Everything else in this function is for after open brace + + # The class may have been declared with namespace or classname qualifiers. + # The constructor and destructor will not have those qualifiers. + base_classname = classinfo.name.split('::')[-1] + + # Look for single-argument constructors that aren't marked explicit. + # Technically a valid construct, but against style. + args = Match(r'\s+(?:inline\s+)?%s\s*\(([^,()]+)\)' + % re.escape(base_classname), + line) + if (args and + args.group(1) != 'void' and + not Match(r'(const\s+)?%s\s*(?:<\w+>\s*)?&' % re.escape(base_classname), + args.group(1).strip())): + error(filename, linenum, 'runtime/explicit', 5, + 'Single-argument constructors should be marked explicit.') + + # Look for methods declared virtual. + if Search(r'\bvirtual\b', line): + classinfo.virtual_method_linenumber = linenum + # Only look for a destructor declaration on the same line. It would + # be extremely unlikely for the destructor declaration to occupy + # more than one line. + if Search(r'~%s\s*\(' % base_classname, line): + classinfo.has_virtual_destructor = True + + # Look for class end. + brace_depth = classinfo.brace_depth + brace_depth = brace_depth + line.count('{') - line.count('}') + if brace_depth <= 0: + classinfo = classinfo_stack.pop() + # Try to detect missing virtual destructor declarations. + # For now, only warn if a non-derived class with virtual methods lacks + # a virtual destructor. This is to make it less likely that people will + # declare derived virtual destructors without declaring the base + # destructor virtual. + if ((classinfo.virtual_method_linenumber is not None) and + (not classinfo.has_virtual_destructor) and + (not classinfo.is_derived)): # Only warn for base classes + error(filename, classinfo.linenum, 'runtime/virtual', 4, + 'The class %s probably needs a virtual destructor due to ' + 'having virtual method(s), one declared at line %d.' + % (classinfo.name, classinfo.virtual_method_linenumber)) + else: + classinfo.brace_depth = brace_depth + + +def CheckSpacingForFunctionCall(filename, line, linenum, error): + """Checks for the correctness of various spacing around function calls. + + Args: + filename: The name of the current file. + line: The text of the line to check. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + # Since function calls often occur inside if/for/while/switch + # expressions - which have their own, more liberal conventions - we + # first see if we should be looking inside such an expression for a + # function call, to which we can apply more strict standards. + fncall = line # if there's no control flow construct, look at whole line + for pattern in (r'\bif\s*\((.*)\)\s*{', + r'\bfor\s*\((.*)\)\s*{', + r'\bwhile\s*\((.*)\)\s*[{;]', + r'\bswitch\s*\((.*)\)\s*{'): + match = Search(pattern, line) + if match: + fncall = match.group(1) # look inside the parens for function calls + break + + # Except in if/for/while/switch, there should never be space + # immediately inside parens (eg "f( 3, 4 )"). We make an exception + # for nested parens ( (a+b) + c ). Likewise, there should never be + # a space before a ( when it's a function argument. I assume it's a + # function argument when the char before the whitespace is legal in + # a function name (alnum + _) and we're not starting a macro. Also ignore + # pointers and references to arrays and functions coz they're too tricky: + # we use a very simple way to recognize these: + # " (something)(maybe-something)" or + # " (something)(maybe-something," or + # " (something)[something]" + # Note that we assume the contents of [] to be short enough that + # they'll never need to wrap. + if ( # Ignore control structures. + not Search(r'\b(if|for|while|switch|return|delete)\b', fncall) and + # Ignore pointers/references to functions. + not Search(r' \([^)]+\)\([^)]*(\)|,$)', fncall) and + # Ignore pointers/references to arrays. + not Search(r' \([^)]+\)\[[^\]]+\]', fncall)): + if Search(r'\w\s*\(\s(?!\s*\\$)', fncall): # a ( used for a fn call + error(filename, linenum, 'whitespace/parens', 4, + 'Extra space after ( in function call') + elif Search(r'\(\s+(?!(\s*\\)|\()', fncall): + error(filename, linenum, 'whitespace/parens', 2, + 'Extra space after (') + if (Search(r'\w\s+\(', fncall) and + not Search(r'#\s*define|typedef', fncall)): + error(filename, linenum, 'whitespace/parens', 4, + 'Extra space before ( in function call') + # If the ) is followed only by a newline or a { + newline, assume it's + # part of a control statement (if/while/etc), and don't complain + if Search(r'[^)]\s+\)\s*[^{\s]', fncall): + # If the closing parenthesis is preceded by only whitespaces, + # try to give a more descriptive error message. + if Search(r'^\s+\)', fncall): + error(filename, linenum, 'whitespace/parens', 2, + 'Closing ) should be moved to the previous line') + else: + error(filename, linenum, 'whitespace/parens', 2, + 'Extra space before )') + + +def IsBlankLine(line): + """Returns true if the given line is blank. + + We consider a line to be blank if the line is empty or consists of + only white spaces. + + Args: + line: A line of a string. + + Returns: + True, if the given line is blank. + """ + return not line or line.isspace() + + +def CheckForFunctionLengths(filename, clean_lines, linenum, + function_state, error): + """Reports for long function bodies. + + For an overview why this is done, see: + http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Write_Short_Functions + + Uses a simplistic algorithm assuming other style guidelines + (especially spacing) are followed. + Only checks unindented functions, so class members are unchecked. + Trivial bodies are unchecked, so constructors with huge initializer lists + may be missed. + Blank/comment lines are not counted so as to avoid encouraging the removal + of vertical space and comments just to get through a lint check. + NOLINT *on the last line of a function* disables this check. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + function_state: Current function name and lines in body so far. + error: The function to call with any errors found. + """ + lines = clean_lines.lines + line = lines[linenum] + raw = clean_lines.raw_lines + raw_line = raw[linenum] + joined_line = '' + + starting_func = False + regexp = r'(\w(\w|::|\*|\&|\s)*)\(' # decls * & space::name( ... + match_result = Match(regexp, line) + if match_result: + # If the name is all caps and underscores, figure it's a macro and + # ignore it, unless it's TEST or TEST_F. + function_name = match_result.group(1).split()[-1] + if function_name == 'TEST' or function_name == 'TEST_F' or ( + not Match(r'[A-Z_]+$', function_name)): + starting_func = True + + if starting_func: + body_found = False + for start_linenum in xrange(linenum, clean_lines.NumLines()): + start_line = lines[start_linenum] + joined_line += ' ' + start_line.lstrip() + if Search(r'(;|})', start_line): # Declarations and trivial functions + body_found = True + break # ... ignore + elif Search(r'{', start_line): + body_found = True + function = Search(r'((\w|:)*)\(', line).group(1) + if Match(r'TEST', function): # Handle TEST... macros + parameter_regexp = Search(r'(\(.*\))', joined_line) + if parameter_regexp: # Ignore bad syntax + function += parameter_regexp.group(1) + else: + function += '()' + function_state.Begin(function) + break + if not body_found: + # No body for the function (or evidence of a non-function) was found. + error(filename, linenum, 'readability/fn_size', 5, + 'Lint failed to find start of function body.') + elif Match(r'^\}\s*$', line): # function end + function_state.Check(error, filename, linenum) + function_state.End() + elif not Match(r'^\s*$', line): + function_state.Count() # Count non-blank/non-comment lines. + + +_RE_PATTERN_TODO = re.compile(r'^//(\s*)TODO(\(.+?\))?:?(\s|$)?') + + +def CheckComment(comment, filename, linenum, error): + """Checks for common mistakes in TODO comments. + + Args: + comment: The text of the comment from the line in question. + filename: The name of the current file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + match = _RE_PATTERN_TODO.match(comment) + if match: + # One whitespace is correct; zero whitespace is handled elsewhere. + leading_whitespace = match.group(1) + if len(leading_whitespace) > 1: + error(filename, linenum, 'whitespace/todo', 2, + 'Too many spaces before TODO') + + username = match.group(2) + if not username: + error(filename, linenum, 'readability/todo', 2, + 'Missing username in TODO; it should look like ' + '"// TODO(my_username): Stuff."') + + middle_whitespace = match.group(3) + # Comparisons made explicit for correctness -- pylint: disable-msg=C6403 + if middle_whitespace != ' ' and middle_whitespace != '': + error(filename, linenum, 'whitespace/todo', 2, + 'TODO(my_username) should be followed by a space') + + +def CheckSpacing(filename, clean_lines, linenum, error): + """Checks for the correctness of various spacing issues in the code. + + Things we check for: spaces around operators, spaces after + if/for/while/switch, no spaces around parens in function calls, two + spaces between code and comment, don't start a block with a blank + line, don't end a function with a blank line, don't add a blank line + after public/protected/private, don't have too many blank lines in a row. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + raw = clean_lines.raw_lines + line = raw[linenum] + + # Before nixing comments, check if the line is blank for no good + # reason. This includes the first line after a block is opened, and + # blank lines at the end of a function (ie, right before a line like '}' + if IsBlankLine(line): + elided = clean_lines.elided + prev_line = elided[linenum - 1] + prevbrace = prev_line.rfind('{') + # TODO(unknown): Don't complain if line before blank line, and line after, + # both start with alnums and are indented the same amount. + # This ignores whitespace at the start of a namespace block + # because those are not usually indented. + if (prevbrace != -1 and prev_line[prevbrace:].find('}') == -1 + and prev_line[:prevbrace].find('namespace') == -1): + # OK, we have a blank line at the start of a code block. Before we + # complain, we check if it is an exception to the rule: The previous + # non-empty line has the parameters of a function header that are indented + # 4 spaces (because they did not fit in a 80 column line when placed on + # the same line as the function name). We also check for the case where + # the previous line is indented 6 spaces, which may happen when the + # initializers of a constructor do not fit into a 80 column line. + exception = False + if Match(r' {6}\w', prev_line): # Initializer list? + # We are looking for the opening column of initializer list, which + # should be indented 4 spaces to cause 6 space indentation afterwards. + search_position = linenum-2 + while (search_position >= 0 + and Match(r' {6}\w', elided[search_position])): + search_position -= 1 + exception = (search_position >= 0 + and elided[search_position][:5] == ' :') + else: + # Search for the function arguments or an initializer list. We use a + # simple heuristic here: If the line is indented 4 spaces; and we have a + # closing paren, without the opening paren, followed by an opening brace + # or colon (for initializer lists) we assume that it is the last line of + # a function header. If we have a colon indented 4 spaces, it is an + # initializer list. + exception = (Match(r' {4}\w[^\(]*\)\s*(const\s*)?(\{\s*$|:)', + prev_line) + or Match(r' {4}:', prev_line)) + + if not exception: + error(filename, linenum, 'whitespace/blank_line', 2, + 'Blank line at the start of a code block. Is this needed?') + # This doesn't ignore whitespace at the end of a namespace block + # because that is too hard without pairing open/close braces; + # however, a special exception is made for namespace closing + # brackets which have a comment containing "namespace". + # + # Also, ignore blank lines at the end of a block in a long if-else + # chain, like this: + # if (condition1) { + # // Something followed by a blank line + # + # } else if (condition2) { + # // Something else + # } + if linenum + 1 < clean_lines.NumLines(): + next_line = raw[linenum + 1] + if (next_line + and Match(r'\s*}', next_line) + and next_line.find('namespace') == -1 + and next_line.find('} else ') == -1): + error(filename, linenum, 'whitespace/blank_line', 3, + 'Blank line at the end of a code block. Is this needed?') + + matched = Match(r'\s*(public|protected|private):', prev_line) + if matched: + error(filename, linenum, 'whitespace/blank_line', 3, + 'Do not leave a blank line after "%s:"' % matched.group(1)) + + # Next, we complain if there's a comment too near the text + commentpos = line.find('//') + if commentpos != -1: + # Check if the // may be in quotes. If so, ignore it + # Comparisons made explicit for clarity -- pylint: disable-msg=C6403 + if (line.count('"', 0, commentpos) - + line.count('\\"', 0, commentpos)) % 2 == 0: # not in quotes + # Allow one space for new scopes, two spaces otherwise: + if (not Match(r'^\s*{ //', line) and + ((commentpos >= 1 and + line[commentpos-1] not in string.whitespace) or + (commentpos >= 2 and + line[commentpos-2] not in string.whitespace))): + error(filename, linenum, 'whitespace/comments', 2, + 'At least two spaces is best between code and comments') + # There should always be a space between the // and the comment + commentend = commentpos + 2 + if commentend < len(line) and not line[commentend] == ' ': + # but some lines are exceptions -- e.g. if they're big + # comment delimiters like: + # //---------------------------------------------------------- + # or are an empty C++ style Doxygen comment, like: + # /// + # or they begin with multiple slashes followed by a space: + # //////// Header comment + match = (Search(r'[=/-]{4,}\s*$', line[commentend:]) or + Search(r'^/$', line[commentend:]) or + Search(r'^/+ ', line[commentend:])) + if not match: + error(filename, linenum, 'whitespace/comments', 4, + 'Should have a space between // and comment') + CheckComment(line[commentpos:], filename, linenum, error) + + line = clean_lines.elided[linenum] # get rid of comments and strings + + # Don't try to do spacing checks for operator methods + line = re.sub(r'operator(==|!=|<|<<|<=|>=|>>|>)\(', 'operator\(', line) + + # We allow no-spaces around = within an if: "if ( (a=Foo()) == 0 )". + # Otherwise not. Note we only check for non-spaces on *both* sides; + # sometimes people put non-spaces on one side when aligning ='s among + # many lines (not that this is behavior that I approve of...) + if Search(r'[\w.]=[\w.]', line) and not Search(r'\b(if|while) ', line): + error(filename, linenum, 'whitespace/operators', 4, + 'Missing spaces around =') + + # It's ok not to have spaces around binary operators like + - * /, but if + # there's too little whitespace, we get concerned. It's hard to tell, + # though, so we punt on this one for now. TODO. + + # You should always have whitespace around binary operators. + # Alas, we can't test < or > because they're legitimately used sans spaces + # (a->b, vector a). The only time we can tell is a < with no >, and + # only if it's not template params list spilling into the next line. + match = Search(r'[^<>=!\s](==|!=|<=|>=)[^<>=!\s]', line) + if not match: + # Note that while it seems that the '<[^<]*' term in the following + # regexp could be simplified to '<.*', which would indeed match + # the same class of strings, the [^<] means that searching for the + # regexp takes linear rather than quadratic time. + if not Search(r'<[^<]*,\s*$', line): # template params spill + match = Search(r'[^<>=!\s](<)[^<>=!\s]([^>]|->)*$', line) + if match: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around %s' % match.group(1)) + # We allow no-spaces around << and >> when used like this: 10<<20, but + # not otherwise (particularly, not when used as streams) + match = Search(r'[^0-9\s](<<|>>)[^0-9\s]', line) + if match: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around %s' % match.group(1)) + + # There shouldn't be space around unary operators + match = Search(r'(!\s|~\s|[\s]--[\s;]|[\s]\+\+[\s;])', line) + if match: + error(filename, linenum, 'whitespace/operators', 4, + 'Extra space for operator %s' % match.group(1)) + + # A pet peeve of mine: no spaces after an if, while, switch, or for + match = Search(r' (if\(|for\(|while\(|switch\()', line) + if match: + error(filename, linenum, 'whitespace/parens', 5, + 'Missing space before ( in %s' % match.group(1)) + + # For if/for/while/switch, the left and right parens should be + # consistent about how many spaces are inside the parens, and + # there should either be zero or one spaces inside the parens. + # We don't want: "if ( foo)" or "if ( foo )". + # Exception: "for ( ; foo; bar)" and "for (foo; bar; )" are allowed. + match = Search(r'\b(if|for|while|switch)\s*' + r'\(([ ]*)(.).*[^ ]+([ ]*)\)\s*{\s*$', + line) + if match: + if len(match.group(2)) != len(match.group(4)): + if not (match.group(3) == ';' and + len(match.group(2)) == 1 + len(match.group(4)) or + not match.group(2) and Search(r'\bfor\s*\(.*; \)', line)): + error(filename, linenum, 'whitespace/parens', 5, + 'Mismatching spaces inside () in %s' % match.group(1)) + if not len(match.group(2)) in [0, 1]: + error(filename, linenum, 'whitespace/parens', 5, + 'Should have zero or one spaces inside ( and ) in %s' % + match.group(1)) + + # You should always have a space after a comma (either as fn arg or operator) + if Search(r',[^\s]', line): + error(filename, linenum, 'whitespace/comma', 3, + 'Missing space after ,') + + # You should always have a space after a semicolon + # except for few corner cases + # TODO(unknown): clarify if 'if (1) { return 1;}' is requires one more + # space after ; + if Search(r';[^\s};\\)/]', line): + error(filename, linenum, 'whitespace/semicolon', 3, + 'Missing space after ;') + + # Next we will look for issues with function calls. + CheckSpacingForFunctionCall(filename, line, linenum, error) + + # Except after an opening paren, or after another opening brace (in case of + # an initializer list, for instance), you should have spaces before your + # braces. And since you should never have braces at the beginning of a line, + # this is an easy test. + if Search(r'[^ ({]{', line): + error(filename, linenum, 'whitespace/braces', 5, + 'Missing space before {') + + # Make sure '} else {' has spaces. + if Search(r'}else', line): + error(filename, linenum, 'whitespace/braces', 5, + 'Missing space before else') + + # You shouldn't have spaces before your brackets, except maybe after + # 'delete []' or 'new char * []'. + if Search(r'\w\s+\[', line) and not Search(r'delete\s+\[', line): + error(filename, linenum, 'whitespace/braces', 5, + 'Extra space before [') + + # You shouldn't have a space before a semicolon at the end of the line. + # There's a special case for "for" since the style guide allows space before + # the semicolon there. + if Search(r':\s*;\s*$', line): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Semicolon defining empty statement. Use { } instead.') + elif Search(r'^\s*;\s*$', line): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Line contains only semicolon. If this should be an empty statement, ' + 'use { } instead.') + elif (Search(r'\s+;\s*$', line) and + not Search(r'\bfor\b', line)): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Extra space before last semicolon. If this should be an empty ' + 'statement, use { } instead.') + + +def CheckSectionSpacing(filename, clean_lines, class_info, linenum, error): + """Checks for additional blank line issues related to sections. + + Currently the only thing checked here is blank line before protected/private. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + class_info: A _ClassInfo objects. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + # Skip checks if the class is small, where small means 25 lines or less. + # 25 lines seems like a good cutoff since that's the usual height of + # terminals, and any class that can't fit in one screen can't really + # be considered "small". + # + # Also skip checks if we are on the first line. This accounts for + # classes that look like + # class Foo { public: ... }; + # + # If we didn't find the end of the class, last_line would be zero, + # and the check will be skipped by the first condition. + if (class_info.last_line - class_info.linenum <= 24 or + linenum <= class_info.linenum): + return + + matched = Match(r'\s*(public|protected|private):', clean_lines.lines[linenum]) + if matched: + # Issue warning if the line before public/protected/private was + # not a blank line, but don't do this if the previous line contains + # "class" or "struct". This can happen two ways: + # - We are at the beginning of the class. + # - We are forward-declaring an inner class that is semantically + # private, but needed to be public for implementation reasons. + prev_line = clean_lines.lines[linenum - 1] + if (not IsBlankLine(prev_line) and + not Search(r'\b(class|struct)\b', prev_line)): + # Try a bit harder to find the beginning of the class. This is to + # account for multi-line base-specifier lists, e.g.: + # class Derived + # : public Base { + end_class_head = class_info.linenum + for i in range(class_info.linenum, linenum): + if Search(r'\{\s*$', clean_lines.lines[i]): + end_class_head = i + break + if end_class_head < linenum - 1: + error(filename, linenum, 'whitespace/blank_line', 3, + '"%s:" should be preceded by a blank line' % matched.group(1)) + + +def GetPreviousNonBlankLine(clean_lines, linenum): + """Return the most recent non-blank line and its line number. + + Args: + clean_lines: A CleansedLines instance containing the file contents. + linenum: The number of the line to check. + + Returns: + A tuple with two elements. The first element is the contents of the last + non-blank line before the current line, or the empty string if this is the + first non-blank line. The second is the line number of that line, or -1 + if this is the first non-blank line. + """ + + prevlinenum = linenum - 1 + while prevlinenum >= 0: + prevline = clean_lines.elided[prevlinenum] + if not IsBlankLine(prevline): # if not a blank line... + return (prevline, prevlinenum) + prevlinenum -= 1 + return ('', -1) + + +def CheckBraces(filename, clean_lines, linenum, error): + """Looks for misplaced braces (e.g. at the end of line). + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + line = clean_lines.elided[linenum] # get rid of comments and strings + + if Match(r'\s*{\s*$', line): + # We allow an open brace to start a line in the case where someone + # is using braces in a block to explicitly create a new scope, + # which is commonly used to control the lifetime of + # stack-allocated variables. We don't detect this perfectly: we + # just don't complain if the last non-whitespace character on the + # previous non-blank line is ';', ':', '{', or '}'. + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if not Search(r'[;:}{]\s*$', prevline): + error(filename, linenum, 'whitespace/braces', 4, + '{ should almost always be at the end of the previous line') + + # An else clause should be on the same line as the preceding closing brace. + if Match(r'\s*else\s*', line): + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if Match(r'\s*}\s*$', prevline): + error(filename, linenum, 'whitespace/newline', 4, + 'An else should appear on the same line as the preceding }') + + # If braces come on one side of an else, they should be on both. + # However, we have to worry about "else if" that spans multiple lines! + if Search(r'}\s*else[^{]*$', line) or Match(r'[^}]*else\s*{', line): + if Search(r'}\s*else if([^{]*)$', line): # could be multi-line if + # find the ( after the if + pos = line.find('else if') + pos = line.find('(', pos) + if pos > 0: + (endline, _, endpos) = CloseExpression(clean_lines, linenum, pos) + if endline[endpos:].find('{') == -1: # must be brace after if + error(filename, linenum, 'readability/braces', 5, + 'If an else has a brace on one side, it should have it on both') + else: # common case: else not followed by a multi-line if + error(filename, linenum, 'readability/braces', 5, + 'If an else has a brace on one side, it should have it on both') + + # Likewise, an else should never have the else clause on the same line + if Search(r'\belse [^\s{]', line) and not Search(r'\belse if\b', line): + error(filename, linenum, 'whitespace/newline', 4, + 'Else clause should never be on same line as else (use 2 lines)') + + # In the same way, a do/while should never be on one line + if Match(r'\s*do [^\s{]', line): + error(filename, linenum, 'whitespace/newline', 4, + 'do/while clauses should not be on a single line') + + # Braces shouldn't be followed by a ; unless they're defining a struct + # or initializing an array. + # We can't tell in general, but we can for some common cases. + prevlinenum = linenum + while True: + (prevline, prevlinenum) = GetPreviousNonBlankLine(clean_lines, prevlinenum) + if Match(r'\s+{.*}\s*;', line) and not prevline.count(';'): + line = prevline + line + else: + break + if (Search(r'{.*}\s*;', line) and + line.count('{') == line.count('}') and + not Search(r'struct|class|enum|\s*=\s*{', line)): + error(filename, linenum, 'readability/braces', 4, + "You don't need a ; after a }") + + +def ReplaceableCheck(operator, macro, line): + """Determine whether a basic CHECK can be replaced with a more specific one. + + For example suggest using CHECK_EQ instead of CHECK(a == b) and + similarly for CHECK_GE, CHECK_GT, CHECK_LE, CHECK_LT, CHECK_NE. + + Args: + operator: The C++ operator used in the CHECK. + macro: The CHECK or EXPECT macro being called. + line: The current source line. + + Returns: + True if the CHECK can be replaced with a more specific one. + """ + + # This matches decimal and hex integers, strings, and chars (in that order). + match_constant = r'([-+]?(\d+|0[xX][0-9a-fA-F]+)[lLuU]{0,3}|".*"|\'.*\')' + + # Expression to match two sides of the operator with something that + # looks like a literal, since CHECK(x == iterator) won't compile. + # This means we can't catch all the cases where a more specific + # CHECK is possible, but it's less annoying than dealing with + # extraneous warnings. + match_this = (r'\s*' + macro + r'\((\s*' + + match_constant + r'\s*' + operator + r'[^<>].*|' + r'.*[^<>]' + operator + r'\s*' + match_constant + + r'\s*\))') + + # Don't complain about CHECK(x == NULL) or similar because + # CHECK_EQ(x, NULL) won't compile (requires a cast). + # Also, don't complain about more complex boolean expressions + # involving && or || such as CHECK(a == b || c == d). + return Match(match_this, line) and not Search(r'NULL|&&|\|\|', line) + + +def CheckCheck(filename, clean_lines, linenum, error): + """Checks the use of CHECK and EXPECT macros. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + # Decide the set of replacement macros that should be suggested + raw_lines = clean_lines.raw_lines + current_macro = '' + for macro in _CHECK_MACROS: + if raw_lines[linenum].find(macro) >= 0: + current_macro = macro + break + if not current_macro: + # Don't waste time here if line doesn't contain 'CHECK' or 'EXPECT' + return + + line = clean_lines.elided[linenum] # get rid of comments and strings + + # Encourage replacing plain CHECKs with CHECK_EQ/CHECK_NE/etc. + for operator in ['==', '!=', '>=', '>', '<=', '<']: + if ReplaceableCheck(operator, current_macro, line): + error(filename, linenum, 'readability/check', 2, + 'Consider using %s instead of %s(a %s b)' % ( + _CHECK_REPLACEMENT[current_macro][operator], + current_macro, operator)) + break + + +def GetLineWidth(line): + """Determines the width of the line in column positions. + + Args: + line: A string, which may be a Unicode string. + + Returns: + The width of the line in column positions, accounting for Unicode + combining characters and wide characters. + """ + if isinstance(line, unicode): + width = 0 + for uc in unicodedata.normalize('NFC', line): + if unicodedata.east_asian_width(uc) in ('W', 'F'): + width += 2 + elif not unicodedata.combining(uc): + width += 1 + return width + else: + return len(line) + + +def CheckStyle(filename, clean_lines, linenum, file_extension, class_state, + error): + """Checks rules from the 'C++ style rules' section of cppguide.html. + + Most of these rules are hard to test (naming, comment style), but we + do what we can. In particular we check for 2-space indents, line lengths, + tab usage, spaces inside code, etc. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + file_extension: The extension (without the dot) of the filename. + error: The function to call with any errors found. + """ + + raw_lines = clean_lines.raw_lines + line = raw_lines[linenum] + + if line.find('\t') != -1: + error(filename, linenum, 'whitespace/tab', 1, + 'Tab found; better to use spaces') + + # One or three blank spaces at the beginning of the line is weird; it's + # hard to reconcile that with 2-space indents. + # NOTE: here are the conditions rob pike used for his tests. Mine aren't + # as sophisticated, but it may be worth becoming so: RLENGTH==initial_spaces + # if(RLENGTH > 20) complain = 0; + # if(match($0, " +(error|private|public|protected):")) complain = 0; + # if(match(prev, "&& *$")) complain = 0; + # if(match(prev, "\\|\\| *$")) complain = 0; + # if(match(prev, "[\",=><] *$")) complain = 0; + # if(match($0, " <<")) complain = 0; + # if(match(prev, " +for \\(")) complain = 0; + # if(prevodd && match(prevprev, " +for \\(")) complain = 0; + initial_spaces = 0 + cleansed_line = clean_lines.elided[linenum] + while initial_spaces < len(line) and line[initial_spaces] == ' ': + initial_spaces += 1 + if line and line[-1].isspace(): + error(filename, linenum, 'whitespace/end_of_line', 4, + 'Line ends in whitespace. Consider deleting these extra spaces.') + # There are certain situations we allow one space, notably for labels + elif ((initial_spaces == 1 or initial_spaces == 3) and + not Match(r'\s*\w+\s*:\s*$', cleansed_line)): + error(filename, linenum, 'whitespace/indent', 3, + 'Weird number of spaces at line-start. ' + 'Are you using a 2-space indent?') + # Labels should always be indented at least one space. + elif not initial_spaces and line[:2] != '//' and Search(r'[^:]:\s*$', + line): + error(filename, linenum, 'whitespace/labels', 4, + 'Labels should always be indented at least one space. ' + 'If this is a member-initializer list in a constructor or ' + 'the base class list in a class definition, the colon should ' + 'be on the following line.') + + + # Check if the line is a header guard. + is_header_guard = False + if file_extension == 'h': + cppvar = GetHeaderGuardCPPVariable(filename) + if (line.startswith('#ifndef %s' % cppvar) or + line.startswith('#define %s' % cppvar) or + line.startswith('#endif // %s' % cppvar)): + is_header_guard = True + # #include lines and header guards can be long, since there's no clean way to + # split them. + # + # URLs can be long too. It's possible to split these, but it makes them + # harder to cut&paste. + # + # The "$Id:...$" comment may also get very long without it being the + # developers fault. + if (not line.startswith('#include') and not is_header_guard and + not Match(r'^\s*//.*http(s?)://\S*$', line) and + not Match(r'^// \$Id:.*#[0-9]+ \$$', line)): + line_width = GetLineWidth(line) + if line_width > 100: + error(filename, linenum, 'whitespace/line_length', 4, + 'Lines should very rarely be longer than 100 characters') + elif line_width > 80: + error(filename, linenum, 'whitespace/line_length', 2, + 'Lines should be <= 80 characters long') + + if (cleansed_line.count(';') > 1 and + # for loops are allowed two ;'s (and may run over two lines). + cleansed_line.find('for') == -1 and + (GetPreviousNonBlankLine(clean_lines, linenum)[0].find('for') == -1 or + GetPreviousNonBlankLine(clean_lines, linenum)[0].find(';') != -1) and + # It's ok to have many commands in a switch case that fits in 1 line + not ((cleansed_line.find('case ') != -1 or + cleansed_line.find('default:') != -1) and + cleansed_line.find('break;') != -1)): + error(filename, linenum, 'whitespace/newline', 4, + 'More than one command on the same line') + + # Some more style checks + CheckBraces(filename, clean_lines, linenum, error) + CheckSpacing(filename, clean_lines, linenum, error) + CheckCheck(filename, clean_lines, linenum, error) + if class_state and class_state.classinfo_stack: + CheckSectionSpacing(filename, clean_lines, + class_state.classinfo_stack[-1], linenum, error) + + +_RE_PATTERN_INCLUDE_NEW_STYLE = re.compile(r'#include +"[^/]+\.h"') +_RE_PATTERN_INCLUDE = re.compile(r'^\s*#\s*include\s*([<"])([^>"]*)[>"].*$') +# Matches the first component of a filename delimited by -s and _s. That is: +# _RE_FIRST_COMPONENT.match('foo').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo.cc').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo-bar_baz.cc').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo_bar-baz.cc').group(0) == 'foo' +_RE_FIRST_COMPONENT = re.compile(r'^[^-_.]+') + + +def _DropCommonSuffixes(filename): + """Drops common suffixes like _test.cc or -inl.h from filename. + + For example: + >>> _DropCommonSuffixes('foo/foo-inl.h') + 'foo/foo' + >>> _DropCommonSuffixes('foo/bar/foo.cc') + 'foo/bar/foo' + >>> _DropCommonSuffixes('foo/foo_internal.h') + 'foo/foo' + >>> _DropCommonSuffixes('foo/foo_unusualinternal.h') + 'foo/foo_unusualinternal' + + Args: + filename: The input filename. + + Returns: + The filename with the common suffix removed. + """ + for suffix in ('test.cc', 'regtest.cc', 'unittest.cc', + 'inl.h', 'impl.h', 'internal.h'): + if (filename.endswith(suffix) and len(filename) > len(suffix) and + filename[-len(suffix) - 1] in ('-', '_')): + return filename[:-len(suffix) - 1] + return os.path.splitext(filename)[0] + + +def _IsTestFilename(filename): + """Determines if the given filename has a suffix that identifies it as a test. + + Args: + filename: The input filename. + + Returns: + True if 'filename' looks like a test, False otherwise. + """ + if (filename.endswith('_test.cc') or + filename.endswith('_unittest.cc') or + filename.endswith('_regtest.cc')): + return True + else: + return False + + +def _ClassifyInclude(fileinfo, include, is_system): + """Figures out what kind of header 'include' is. + + Args: + fileinfo: The current file cpplint is running over. A FileInfo instance. + include: The path to a #included file. + is_system: True if the #include used <> rather than "". + + Returns: + One of the _XXX_HEADER constants. + + For example: + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'stdio.h', True) + _C_SYS_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'string', True) + _CPP_SYS_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/foo.h', False) + _LIKELY_MY_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo_unknown_extension.cc'), + ... 'bar/foo_other_ext.h', False) + _POSSIBLE_MY_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/bar.h', False) + _OTHER_HEADER + """ + # This is a list of all standard c++ header files, except + # those already checked for above. + is_stl_h = include in _STL_HEADERS + is_cpp_h = is_stl_h or include in _CPP_HEADERS + + if is_system: + if is_cpp_h: + return _CPP_SYS_HEADER + else: + return _C_SYS_HEADER + + # If the target file and the include we're checking share a + # basename when we drop common extensions, and the include + # lives in . , then it's likely to be owned by the target file. + target_dir, target_base = ( + os.path.split(_DropCommonSuffixes(fileinfo.RepositoryName()))) + include_dir, include_base = os.path.split(_DropCommonSuffixes(include)) + if target_base == include_base and ( + include_dir == target_dir or + include_dir == os.path.normpath(target_dir + '/../public')): + return _LIKELY_MY_HEADER + + # If the target and include share some initial basename + # component, it's possible the target is implementing the + # include, so it's allowed to be first, but we'll never + # complain if it's not there. + target_first_component = _RE_FIRST_COMPONENT.match(target_base) + include_first_component = _RE_FIRST_COMPONENT.match(include_base) + if (target_first_component and include_first_component and + target_first_component.group(0) == + include_first_component.group(0)): + return _POSSIBLE_MY_HEADER + + return _OTHER_HEADER + + + +def CheckIncludeLine(filename, clean_lines, linenum, include_state, error): + """Check rules that are applicable to #include lines. + + Strings on #include lines are NOT removed from elided line, to make + certain tasks easier. However, to prevent false positives, checks + applicable to #include lines in CheckLanguage must be put here. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + include_state: An _IncludeState instance in which the headers are inserted. + error: The function to call with any errors found. + """ + fileinfo = FileInfo(filename) + + line = clean_lines.lines[linenum] + + # "include" should use the new style "foo/bar.h" instead of just "bar.h" + if _RE_PATTERN_INCLUDE_NEW_STYLE.search(line): + error(filename, linenum, 'build/include', 4, + 'Include the directory when naming .h files') + + # we shouldn't include a file more than once. actually, there are a + # handful of instances where doing so is okay, but in general it's + # not. + match = _RE_PATTERN_INCLUDE.search(line) + if match: + include = match.group(2) + is_system = (match.group(1) == '<') + if include in include_state: + error(filename, linenum, 'build/include', 4, + '"%s" already included at %s:%s' % + (include, filename, include_state[include])) + else: + include_state[include] = linenum + + # We want to ensure that headers appear in the right order: + # 1) for foo.cc, foo.h (preferred location) + # 2) c system files + # 3) cpp system files + # 4) for foo.cc, foo.h (deprecated location) + # 5) other google headers + # + # We classify each include statement as one of those 5 types + # using a number of techniques. The include_state object keeps + # track of the highest type seen, and complains if we see a + # lower type after that. + error_message = include_state.CheckNextIncludeOrder( + _ClassifyInclude(fileinfo, include, is_system)) + if error_message: + error(filename, linenum, 'build/include_order', 4, + '%s. Should be: %s.h, c system, c++ system, other.' % + (error_message, fileinfo.BaseName())) + if not include_state.IsInAlphabeticalOrder(include): + error(filename, linenum, 'build/include_alpha', 4, + 'Include "%s" not in alphabetical order' % include) + + # Look for any of the stream classes that are part of standard C++. + match = _RE_PATTERN_INCLUDE.match(line) + if match: + include = match.group(2) + if Match(r'(f|ind|io|i|o|parse|pf|stdio|str|)?stream$', include): + # Many unit tests use cout, so we exempt them. + if not _IsTestFilename(filename): + error(filename, linenum, 'readability/streams', 3, + 'Streams are highly discouraged.') + + +def _GetTextInside(text, start_pattern): + """Retrieves all the text between matching open and close parentheses. + + Given a string of lines and a regular expression string, retrieve all the text + following the expression and between opening punctuation symbols like + (, [, or {, and the matching close-punctuation symbol. This properly nested + occurrences of the punctuations, so for the text like + printf(a(), b(c())); + a call to _GetTextInside(text, r'printf\(') will return 'a(), b(c())'. + start_pattern must match string having an open punctuation symbol at the end. + + Args: + text: The lines to extract text. Its comments and strings must be elided. + It can be single line and can span multiple lines. + start_pattern: The regexp string indicating where to start extracting + the text. + Returns: + The extracted text. + None if either the opening string or ending punctuation could not be found. + """ + # TODO(sugawarayu): Audit cpplint.py to see what places could be profitably + # rewritten to use _GetTextInside (and use inferior regexp matching today). + + # Give opening punctuations to get the matching close-punctuations. + matching_punctuation = {'(': ')', '{': '}', '[': ']'} + closing_punctuation = set(matching_punctuation.itervalues()) + + # Find the position to start extracting text. + match = re.search(start_pattern, text, re.M) + if not match: # start_pattern not found in text. + return None + start_position = match.end(0) + + assert start_position > 0, ( + 'start_pattern must ends with an opening punctuation.') + assert text[start_position - 1] in matching_punctuation, ( + 'start_pattern must ends with an opening punctuation.') + # Stack of closing punctuations we expect to have in text after position. + punctuation_stack = [matching_punctuation[text[start_position - 1]]] + position = start_position + while punctuation_stack and position < len(text): + if text[position] == punctuation_stack[-1]: + punctuation_stack.pop() + elif text[position] in closing_punctuation: + # A closing punctuation without matching opening punctuations. + return None + elif text[position] in matching_punctuation: + punctuation_stack.append(matching_punctuation[text[position]]) + position += 1 + if punctuation_stack: + # Opening punctuations left without matching close-punctuations. + return None + # punctuations match. + return text[start_position:position - 1] + + +def CheckLanguage(filename, clean_lines, linenum, file_extension, include_state, + error): + """Checks rules from the 'C++ language rules' section of cppguide.html. + + Some of these rules are hard to test (function overloading, using + uint32 inappropriately), but we do the best we can. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + file_extension: The extension (without the dot) of the filename. + include_state: An _IncludeState instance in which the headers are inserted. + error: The function to call with any errors found. + """ + # If the line is empty or consists of entirely a comment, no need to + # check it. + line = clean_lines.elided[linenum] + if not line: + return + + match = _RE_PATTERN_INCLUDE.search(line) + if match: + CheckIncludeLine(filename, clean_lines, linenum, include_state, error) + return + + # Create an extended_line, which is the concatenation of the current and + # next lines, for more effective checking of code that may span more than one + # line. + if linenum + 1 < clean_lines.NumLines(): + extended_line = line + clean_lines.elided[linenum + 1] + else: + extended_line = line + + # Make Windows paths like Unix. + fullname = os.path.abspath(filename).replace('\\', '/') + + # TODO(unknown): figure out if they're using default arguments in fn proto. + + # Check for non-const references in functions. This is tricky because & + # is also used to take the address of something. We allow <> for templates, + # (ignoring whatever is between the braces) and : for classes. + # These are complicated re's. They try to capture the following: + # paren (for fn-prototype start), typename, &, varname. For the const + # version, we're willing for const to be before typename or after + # Don't check the implementation on same line. + fnline = line.split('{', 1)[0] + if (len(re.findall(r'\([^()]*\b(?:[\w:]|<[^()]*>)+(\s?&|&\s?)\w+', fnline)) > + len(re.findall(r'\([^()]*\bconst\s+(?:typename\s+)?(?:struct\s+)?' + r'(?:[\w:]|<[^()]*>)+(\s?&|&\s?)\w+', fnline)) + + len(re.findall(r'\([^()]*\b(?:[\w:]|<[^()]*>)+\s+const(\s?&|&\s?)[\w]+', + fnline))): + + # We allow non-const references in a few standard places, like functions + # called "swap()" or iostream operators like "<<" or ">>". + if not Search( + r'(swap|Swap|operator[<>][<>])\s*\(\s*(?:[\w:]|<.*>)+\s*&', + fnline): + error(filename, linenum, 'runtime/references', 2, + 'Is this a non-const reference? ' + 'If so, make const or use a pointer.') + + # Check to see if they're using an conversion function cast. + # I just try to capture the most common basic types, though there are more. + # Parameterless conversion functions, such as bool(), are allowed as they are + # probably a member operator declaration or default constructor. + match = Search( + r'(\bnew\s+)?\b' # Grab 'new' operator, if it's there + r'(int|float|double|bool|char|int32|uint32|int64|uint64)\([^)]', line) + if match: + # gMock methods are defined using some variant of MOCK_METHODx(name, type) + # where type may be float(), int(string), etc. Without context they are + # virtually indistinguishable from int(x) casts. Likewise, gMock's + # MockCallback takes a template parameter of the form return_type(arg_type), + # which looks much like the cast we're trying to detect. + if (match.group(1) is None and # If new operator, then this isn't a cast + not (Match(r'^\s*MOCK_(CONST_)?METHOD\d+(_T)?\(', line) or + Match(r'^\s*MockCallback<.*>', line))): + error(filename, linenum, 'readability/casting', 4, + 'Using deprecated casting style. ' + 'Use static_cast<%s>(...) instead' % + match.group(2)) + + CheckCStyleCast(filename, linenum, line, clean_lines.raw_lines[linenum], + 'static_cast', + r'\((int|float|double|bool|char|u?int(16|32|64))\)', error) + + # This doesn't catch all cases. Consider (const char * const)"hello". + # + # (char *) "foo" should always be a const_cast (reinterpret_cast won't + # compile). + if CheckCStyleCast(filename, linenum, line, clean_lines.raw_lines[linenum], + 'const_cast', r'\((char\s?\*+\s?)\)\s*"', error): + pass + else: + # Check pointer casts for other than string constants + CheckCStyleCast(filename, linenum, line, clean_lines.raw_lines[linenum], + 'reinterpret_cast', r'\((\w+\s?\*+\s?)\)', error) + + # In addition, we look for people taking the address of a cast. This + # is dangerous -- casts can assign to temporaries, so the pointer doesn't + # point where you think. + if Search( + r'(&\([^)]+\)[\w(])|(&(static|dynamic|reinterpret)_cast\b)', line): + error(filename, linenum, 'runtime/casting', 4, + ('Are you taking an address of a cast? ' + 'This is dangerous: could be a temp var. ' + 'Take the address before doing the cast, rather than after')) + + # Check for people declaring static/global STL strings at the top level. + # This is dangerous because the C++ language does not guarantee that + # globals with constructors are initialized before the first access. + match = Match( + r'((?:|static +)(?:|const +))string +([a-zA-Z0-9_:]+)\b(.*)', + line) + # Make sure it's not a function. + # Function template specialization looks like: "string foo(...". + # Class template definitions look like: "string Foo::Method(...". + if match and not Match(r'\s*(<.*>)?(::[a-zA-Z0-9_]+)?\s*\(([^"]|$)', + match.group(3)): + error(filename, linenum, 'runtime/string', 4, + 'For a static/global string constant, use a C style string instead: ' + '"%schar %s[]".' % + (match.group(1), match.group(2))) + + # Check that we're not using RTTI outside of testing code. + if Search(r'\bdynamic_cast<', line) and not _IsTestFilename(filename): + error(filename, linenum, 'runtime/rtti', 5, + 'Do not use dynamic_cast<>. If you need to cast within a class ' + "hierarchy, use static_cast<> to upcast. Google doesn't support " + 'RTTI.') + + if Search(r'\b([A-Za-z0-9_]*_)\(\1\)', line): + error(filename, linenum, 'runtime/init', 4, + 'You seem to be initializing a member variable with itself.') + + if file_extension == 'h': + # TODO(unknown): check that 1-arg constructors are explicit. + # How to tell it's a constructor? + # (handled in CheckForNonStandardConstructs for now) + # TODO(unknown): check that classes have DISALLOW_EVIL_CONSTRUCTORS + # (level 1 error) + pass + + # Check if people are using the verboten C basic types. The only exception + # we regularly allow is "unsigned short port" for port. + if Search(r'\bshort port\b', line): + if not Search(r'\bunsigned short port\b', line): + error(filename, linenum, 'runtime/int', 4, + 'Use "unsigned short" for ports, not "short"') + else: + match = Search(r'\b(short|long(?! +double)|long long)\b', line) + if match: + error(filename, linenum, 'runtime/int', 4, + 'Use int16/int64/etc, rather than the C type %s' % match.group(1)) + + # When snprintf is used, the second argument shouldn't be a literal. + match = Search(r'snprintf\s*\(([^,]*),\s*([0-9]*)\s*,', line) + if match and match.group(2) != '0': + # If 2nd arg is zero, snprintf is used to calculate size. + error(filename, linenum, 'runtime/printf', 3, + 'If you can, use sizeof(%s) instead of %s as the 2nd arg ' + 'to snprintf.' % (match.group(1), match.group(2))) + + # Check if some verboten C functions are being used. + if Search(r'\bsprintf\b', line): + error(filename, linenum, 'runtime/printf', 5, + 'Never use sprintf. Use snprintf instead.') + match = Search(r'\b(strcpy|strcat)\b', line) + if match: + error(filename, linenum, 'runtime/printf', 4, + 'Almost always, snprintf is better than %s' % match.group(1)) + + if Search(r'\bsscanf\b', line): + error(filename, linenum, 'runtime/printf', 1, + 'sscanf can be ok, but is slow and can overflow buffers.') + + # Check if some verboten operator overloading is going on + # TODO(unknown): catch out-of-line unary operator&: + # class X {}; + # int operator&(const X& x) { return 42; } // unary operator& + # The trick is it's hard to tell apart from binary operator&: + # class Y { int operator&(const Y& x) { return 23; } }; // binary operator& + if Search(r'\boperator\s*&\s*\(\s*\)', line): + error(filename, linenum, 'runtime/operator', 4, + 'Unary operator& is dangerous. Do not use it.') + + # Check for suspicious usage of "if" like + # } if (a == b) { + if Search(r'\}\s*if\s*\(', line): + error(filename, linenum, 'readability/braces', 4, + 'Did you mean "else if"? If not, start a new line for "if".') + + # Check for potential format string bugs like printf(foo). + # We constrain the pattern not to pick things like DocidForPrintf(foo). + # Not perfect but it can catch printf(foo.c_str()) and printf(foo->c_str()) + # TODO(sugawarayu): Catch the following case. Need to change the calling + # convention of the whole function to process multiple line to handle it. + # printf( + # boy_this_is_a_really_long_variable_that_cannot_fit_on_the_prev_line); + printf_args = _GetTextInside(line, r'(?i)\b(string)?printf\s*\(') + if printf_args: + match = Match(r'([\w.\->()]+)$', printf_args) + if match: + function_name = re.search(r'\b((?:string)?printf)\s*\(', + line, re.I).group(1) + error(filename, linenum, 'runtime/printf', 4, + 'Potential format string bug. Do %s("%%s", %s) instead.' + % (function_name, match.group(1))) + + # Check for potential memset bugs like memset(buf, sizeof(buf), 0). + match = Search(r'memset\s*\(([^,]*),\s*([^,]*),\s*0\s*\)', line) + if match and not Match(r"^''|-?[0-9]+|0x[0-9A-Fa-f]$", match.group(2)): + error(filename, linenum, 'runtime/memset', 4, + 'Did you mean "memset(%s, 0, %s)"?' + % (match.group(1), match.group(2))) + + if Search(r'\busing namespace\b', line): + error(filename, linenum, 'build/namespaces', 5, + 'Do not use namespace using-directives. ' + 'Use using-declarations instead.') + + # Detect variable-length arrays. + match = Match(r'\s*(.+::)?(\w+) [a-z]\w*\[(.+)];', line) + if (match and match.group(2) != 'return' and match.group(2) != 'delete' and + match.group(3).find(']') == -1): + # Split the size using space and arithmetic operators as delimiters. + # If any of the resulting tokens are not compile time constants then + # report the error. + tokens = re.split(r'\s|\+|\-|\*|\/|<<|>>]', match.group(3)) + is_const = True + skip_next = False + for tok in tokens: + if skip_next: + skip_next = False + continue + + if Search(r'sizeof\(.+\)', tok): continue + if Search(r'arraysize\(\w+\)', tok): continue + + tok = tok.lstrip('(') + tok = tok.rstrip(')') + if not tok: continue + if Match(r'\d+', tok): continue + if Match(r'0[xX][0-9a-fA-F]+', tok): continue + if Match(r'k[A-Z0-9]\w*', tok): continue + if Match(r'(.+::)?k[A-Z0-9]\w*', tok): continue + if Match(r'(.+::)?[A-Z][A-Z0-9_]*', tok): continue + # A catch all for tricky sizeof cases, including 'sizeof expression', + # 'sizeof(*type)', 'sizeof(const type)', 'sizeof(struct StructName)' + # requires skipping the next token because we split on ' ' and '*'. + if tok.startswith('sizeof'): + skip_next = True + continue + is_const = False + break + if not is_const: + error(filename, linenum, 'runtime/arrays', 1, + 'Do not use variable-length arrays. Use an appropriately named ' + "('k' followed by CamelCase) compile-time constant for the size.") + + # If DISALLOW_EVIL_CONSTRUCTORS, DISALLOW_COPY_AND_ASSIGN, or + # DISALLOW_IMPLICIT_CONSTRUCTORS is present, then it should be the last thing + # in the class declaration. + match = Match( + (r'\s*' + r'(DISALLOW_(EVIL_CONSTRUCTORS|COPY_AND_ASSIGN|IMPLICIT_CONSTRUCTORS))' + r'\(.*\);$'), + line) + if match and linenum + 1 < clean_lines.NumLines(): + next_line = clean_lines.elided[linenum + 1] + # We allow some, but not all, declarations of variables to be present + # in the statement that defines the class. The [\w\*,\s]* fragment of + # the regular expression below allows users to declare instances of + # the class or pointers to instances, but not less common types such + # as function pointers or arrays. It's a tradeoff between allowing + # reasonable code and avoiding trying to parse more C++ using regexps. + if not Search(r'^\s*}[\w\*,\s]*;', next_line): + error(filename, linenum, 'readability/constructors', 3, + match.group(1) + ' should be the last thing in the class') + + # Check for use of unnamed namespaces in header files. Registration + # macros are typically OK, so we allow use of "namespace {" on lines + # that end with backslashes. + if (file_extension == 'h' + and Search(r'\bnamespace\s*{', line) + and line[-1] != '\\'): + error(filename, linenum, 'build/namespaces', 4, + 'Do not use unnamed namespaces in header files. See ' + 'http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces' + ' for more information.') + + +def CheckCStyleCast(filename, linenum, line, raw_line, cast_type, pattern, + error): + """Checks for a C-style cast by looking for the pattern. + + This also handles sizeof(type) warnings, due to similarity of content. + + Args: + filename: The name of the current file. + linenum: The number of the line to check. + line: The line of code to check. + raw_line: The raw line of code to check, with comments. + cast_type: The string for the C++ cast to recommend. This is either + reinterpret_cast, static_cast, or const_cast, depending. + pattern: The regular expression used to find C-style casts. + error: The function to call with any errors found. + + Returns: + True if an error was emitted. + False otherwise. + """ + match = Search(pattern, line) + if not match: + return False + + # e.g., sizeof(int) + sizeof_match = Match(r'.*sizeof\s*$', line[0:match.start(1) - 1]) + if sizeof_match: + error(filename, linenum, 'runtime/sizeof', 1, + 'Using sizeof(type). Use sizeof(varname) instead if possible') + return True + + remainder = line[match.end(0):] + + # The close paren is for function pointers as arguments to a function. + # eg, void foo(void (*bar)(int)); + # The semicolon check is a more basic function check; also possibly a + # function pointer typedef. + # eg, void foo(int); or void foo(int) const; + # The equals check is for function pointer assignment. + # eg, void *(*foo)(int) = ... + # The > is for MockCallback<...> ... + # + # Right now, this will only catch cases where there's a single argument, and + # it's unnamed. It should probably be expanded to check for multiple + # arguments with some unnamed. + function_match = Match(r'\s*(\)|=|(const)?\s*(;|\{|throw\(\)|>))', remainder) + if function_match: + if (not function_match.group(3) or + function_match.group(3) == ';' or + ('MockCallback<' not in raw_line and + '/*' not in raw_line)): + error(filename, linenum, 'readability/function', 3, + 'All parameters should be named in a function') + return True + + # At this point, all that should be left is actual casts. + error(filename, linenum, 'readability/casting', 4, + 'Using C-style cast. Use %s<%s>(...) instead' % + (cast_type, match.group(1))) + + return True + + +_HEADERS_CONTAINING_TEMPLATES = ( + ('', ('deque',)), + ('', ('unary_function', 'binary_function', + 'plus', 'minus', 'multiplies', 'divides', 'modulus', + 'negate', + 'equal_to', 'not_equal_to', 'greater', 'less', + 'greater_equal', 'less_equal', + 'logical_and', 'logical_or', 'logical_not', + 'unary_negate', 'not1', 'binary_negate', 'not2', + 'bind1st', 'bind2nd', + 'pointer_to_unary_function', + 'pointer_to_binary_function', + 'ptr_fun', + 'mem_fun_t', 'mem_fun', 'mem_fun1_t', 'mem_fun1_ref_t', + 'mem_fun_ref_t', + 'const_mem_fun_t', 'const_mem_fun1_t', + 'const_mem_fun_ref_t', 'const_mem_fun1_ref_t', + 'mem_fun_ref', + )), + ('', ('numeric_limits',)), + ('', ('list',)), + ('', ('map', 'multimap',)), + ('', ('allocator',)), + ('', ('queue', 'priority_queue',)), + ('', ('set', 'multiset',)), + ('', ('stack',)), + ('', ('char_traits', 'basic_string',)), + ('', ('pair',)), + ('', ('vector',)), + + # gcc extensions. + # Note: std::hash is their hash, ::hash is our hash + ('', ('hash_map', 'hash_multimap',)), + ('', ('hash_set', 'hash_multiset',)), + ('', ('slist',)), + ) + +_RE_PATTERN_STRING = re.compile(r'\bstring\b') + +_re_pattern_algorithm_header = [] +for _template in ('copy', 'max', 'min', 'min_element', 'sort', 'swap', + 'transform'): + # Match max(..., ...), max(..., ...), but not foo->max, foo.max or + # type::max(). + _re_pattern_algorithm_header.append( + (re.compile(r'[^>.]\b' + _template + r'(<.*?>)?\([^\)]'), + _template, + '')) + +_re_pattern_templates = [] +for _header, _templates in _HEADERS_CONTAINING_TEMPLATES: + for _template in _templates: + _re_pattern_templates.append( + (re.compile(r'(\<|\b)' + _template + r'\s*\<'), + _template + '<>', + _header)) + + +def FilesBelongToSameModule(filename_cc, filename_h): + """Check if these two filenames belong to the same module. + + The concept of a 'module' here is a as follows: + foo.h, foo-inl.h, foo.cc, foo_test.cc and foo_unittest.cc belong to the + same 'module' if they are in the same directory. + some/path/public/xyzzy and some/path/internal/xyzzy are also considered + to belong to the same module here. + + If the filename_cc contains a longer path than the filename_h, for example, + '/absolute/path/to/base/sysinfo.cc', and this file would include + 'base/sysinfo.h', this function also produces the prefix needed to open the + header. This is used by the caller of this function to more robustly open the + header file. We don't have access to the real include paths in this context, + so we need this guesswork here. + + Known bugs: tools/base/bar.cc and base/bar.h belong to the same module + according to this implementation. Because of this, this function gives + some false positives. This should be sufficiently rare in practice. + + Args: + filename_cc: is the path for the .cc file + filename_h: is the path for the header path + + Returns: + Tuple with a bool and a string: + bool: True if filename_cc and filename_h belong to the same module. + string: the additional prefix needed to open the header file. + """ + + if not filename_cc.endswith('.cc'): + return (False, '') + filename_cc = filename_cc[:-len('.cc')] + if filename_cc.endswith('_unittest'): + filename_cc = filename_cc[:-len('_unittest')] + elif filename_cc.endswith('_test'): + filename_cc = filename_cc[:-len('_test')] + filename_cc = filename_cc.replace('/public/', '/') + filename_cc = filename_cc.replace('/internal/', '/') + + if not filename_h.endswith('.h'): + return (False, '') + filename_h = filename_h[:-len('.h')] + if filename_h.endswith('-inl'): + filename_h = filename_h[:-len('-inl')] + filename_h = filename_h.replace('/public/', '/') + filename_h = filename_h.replace('/internal/', '/') + + files_belong_to_same_module = filename_cc.endswith(filename_h) + common_path = '' + if files_belong_to_same_module: + common_path = filename_cc[:-len(filename_h)] + return files_belong_to_same_module, common_path + + +def UpdateIncludeState(filename, include_state, io=codecs): + """Fill up the include_state with new includes found from the file. + + Args: + filename: the name of the header to read. + include_state: an _IncludeState instance in which the headers are inserted. + io: The io factory to use to read the file. Provided for testability. + + Returns: + True if a header was succesfully added. False otherwise. + """ + headerfile = None + try: + headerfile = io.open(filename, 'r', 'utf8', 'replace') + except IOError: + return False + linenum = 0 + for line in headerfile: + linenum += 1 + clean_line = CleanseComments(line) + match = _RE_PATTERN_INCLUDE.search(clean_line) + if match: + include = match.group(2) + # The value formatting is cute, but not really used right now. + # What matters here is that the key is in include_state. + include_state.setdefault(include, '%s:%d' % (filename, linenum)) + return True + + +def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error, + io=codecs): + """Reports for missing stl includes. + + This function will output warnings to make sure you are including the headers + necessary for the stl containers and functions that you use. We only give one + reason to include a header. For example, if you use both equal_to<> and + less<> in a .h file, only one (the latter in the file) of these will be + reported as a reason to include the . + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + include_state: An _IncludeState instance. + error: The function to call with any errors found. + io: The IO factory to use to read the header file. Provided for unittest + injection. + """ + required = {} # A map of header name to linenumber and the template entity. + # Example of required: { '': (1219, 'less<>') } + + for linenum in xrange(clean_lines.NumLines()): + line = clean_lines.elided[linenum] + if not line or line[0] == '#': + continue + + # String is special -- it is a non-templatized type in STL. + matched = _RE_PATTERN_STRING.search(line) + if matched: + # Don't warn about strings in non-STL namespaces: + # (We check only the first match per line; good enough.) + prefix = line[:matched.start()] + if prefix.endswith('std::') or not prefix.endswith('::'): + required[''] = (linenum, 'string') + + for pattern, template, header in _re_pattern_algorithm_header: + if pattern.search(line): + required[header] = (linenum, template) + + # The following function is just a speed up, no semantics are changed. + if not '<' in line: # Reduces the cpu time usage by skipping lines. + continue + + for pattern, template, header in _re_pattern_templates: + if pattern.search(line): + required[header] = (linenum, template) + + # The policy is that if you #include something in foo.h you don't need to + # include it again in foo.cc. Here, we will look at possible includes. + # Let's copy the include_state so it is only messed up within this function. + include_state = include_state.copy() + + # Did we find the header for this file (if any) and succesfully load it? + header_found = False + + # Use the absolute path so that matching works properly. + abs_filename = FileInfo(filename).FullName() + + # For Emacs's flymake. + # If cpplint is invoked from Emacs's flymake, a temporary file is generated + # by flymake and that file name might end with '_flymake.cc'. In that case, + # restore original file name here so that the corresponding header file can be + # found. + # e.g. If the file name is 'foo_flymake.cc', we should search for 'foo.h' + # instead of 'foo_flymake.h' + abs_filename = re.sub(r'_flymake\.cc$', '.cc', abs_filename) + + # include_state is modified during iteration, so we iterate over a copy of + # the keys. + header_keys = include_state.keys() + for header in header_keys: + (same_module, common_path) = FilesBelongToSameModule(abs_filename, header) + fullpath = common_path + header + if same_module and UpdateIncludeState(fullpath, include_state, io): + header_found = True + + # If we can't find the header file for a .cc, assume it's because we don't + # know where to look. In that case we'll give up as we're not sure they + # didn't include it in the .h file. + # TODO(unknown): Do a better job of finding .h files so we are confident that + # not having the .h file means there isn't one. + if filename.endswith('.cc') and not header_found: + return + + # All the lines have been processed, report the errors found. + for required_header_unstripped in required: + template = required[required_header_unstripped][1] + if required_header_unstripped.strip('<>"') not in include_state: + error(filename, required[required_header_unstripped][0], + 'build/include_what_you_use', 4, + 'Add #include ' + required_header_unstripped + ' for ' + template) + + +_RE_PATTERN_EXPLICIT_MAKEPAIR = re.compile(r'\bmake_pair\s*<') + + +def CheckMakePairUsesDeduction(filename, clean_lines, linenum, error): + """Check that make_pair's template arguments are deduced. + + G++ 4.6 in C++0x mode fails badly if make_pair's template arguments are + specified explicitly, and such use isn't intended in any case. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + raw = clean_lines.raw_lines + line = raw[linenum] + match = _RE_PATTERN_EXPLICIT_MAKEPAIR.search(line) + if match: + error(filename, linenum, 'build/explicit_make_pair', + 4, # 4 = high confidence + 'Omit template arguments from make_pair OR use pair directly OR' + ' if appropriate, construct a pair directly') + + +def ProcessLine(filename, file_extension, + clean_lines, line, include_state, function_state, + class_state, error, extra_check_functions=[]): + """Processes a single line in the file. + + Args: + filename: Filename of the file that is being processed. + file_extension: The extension (dot not included) of the file. + clean_lines: An array of strings, each representing a line of the file, + with comments stripped. + line: Number of line being processed. + include_state: An _IncludeState instance in which the headers are inserted. + function_state: A _FunctionState instance which counts function lines, etc. + class_state: A _ClassState instance which maintains information about + the current stack of nested class declarations being parsed. + error: A callable to which errors are reported, which takes 4 arguments: + filename, line number, error level, and message + extra_check_functions: An array of additional check functions that will be + run on each source line. Each function takes 4 + arguments: filename, clean_lines, line, error + """ + raw_lines = clean_lines.raw_lines + ParseNolintSuppressions(filename, raw_lines[line], line, error) + CheckForFunctionLengths(filename, clean_lines, line, function_state, error) + CheckForMultilineCommentsAndStrings(filename, clean_lines, line, error) + CheckStyle(filename, clean_lines, line, file_extension, class_state, error) + CheckLanguage(filename, clean_lines, line, file_extension, include_state, + error) + CheckForNonStandardConstructs(filename, clean_lines, line, + class_state, error) + CheckPosixThreading(filename, clean_lines, line, error) + CheckInvalidIncrement(filename, clean_lines, line, error) + CheckMakePairUsesDeduction(filename, clean_lines, line, error) + for check_fn in extra_check_functions: + check_fn(filename, clean_lines, line, error) + +def ProcessFileData(filename, file_extension, lines, error, + extra_check_functions=[]): + """Performs lint checks and reports any errors to the given error function. + + Args: + filename: Filename of the file that is being processed. + file_extension: The extension (dot not included) of the file. + lines: An array of strings, each representing a line of the file, with the + last element being empty if the file is terminated with a newline. + error: A callable to which errors are reported, which takes 4 arguments: + filename, line number, error level, and message + extra_check_functions: An array of additional check functions that will be + run on each source line. Each function takes 4 + arguments: filename, clean_lines, line, error + """ + lines = (['// marker so line numbers and indices both start at 1'] + lines + + ['// marker so line numbers end in a known way']) + + include_state = _IncludeState() + function_state = _FunctionState() + class_state = _ClassState() + + ResetNolintSuppressions() + + CheckForCopyright(filename, lines, error) + + if file_extension == 'h': + CheckForHeaderGuard(filename, lines, error) + + RemoveMultiLineComments(filename, lines, error) + clean_lines = CleansedLines(lines) + for line in xrange(clean_lines.NumLines()): + ProcessLine(filename, file_extension, clean_lines, line, + include_state, function_state, class_state, error, + extra_check_functions) + class_state.CheckFinished(filename, error) + + CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error) + + # We check here rather than inside ProcessLine so that we see raw + # lines rather than "cleaned" lines. + CheckForUnicodeReplacementCharacters(filename, lines, error) + + CheckForNewlineAtEOF(filename, lines, error) + +def ProcessFile(filename, vlevel, extra_check_functions=[]): + """Does google-lint on a single file. + + Args: + filename: The name of the file to parse. + + vlevel: The level of errors to report. Every error of confidence + >= verbose_level will be reported. 0 is a good default. + + extra_check_functions: An array of additional check functions that will be + run on each source line. Each function takes 4 + arguments: filename, clean_lines, line, error + """ + + _SetVerboseLevel(vlevel) + + try: + # Support the UNIX convention of using "-" for stdin. Note that + # we are not opening the file with universal newline support + # (which codecs doesn't support anyway), so the resulting lines do + # contain trailing '\r' characters if we are reading a file that + # has CRLF endings. + # If after the split a trailing '\r' is present, it is removed + # below. If it is not expected to be present (i.e. os.linesep != + # '\r\n' as in Windows), a warning is issued below if this file + # is processed. + + if filename == '-': + lines = codecs.StreamReaderWriter(sys.stdin, + codecs.getreader('utf8'), + codecs.getwriter('utf8'), + 'replace').read().split('\n') + else: + lines = codecs.open(filename, 'r', 'utf8', 'replace').read().split('\n') + + carriage_return_found = False + # Remove trailing '\r'. + for linenum in range(len(lines)): + if lines[linenum].endswith('\r'): + lines[linenum] = lines[linenum].rstrip('\r') + carriage_return_found = True + + except IOError: + sys.stderr.write( + "Skipping input '%s': Can't open for reading\n" % filename) + return + + # Note, if no dot is found, this will give the entire filename as the ext. + file_extension = filename[filename.rfind('.') + 1:] + + # When reading from stdin, the extension is unknown, so no cpplint tests + # should rely on the extension. + if (filename != '-' and file_extension != 'cc' and file_extension != 'h' + and file_extension != 'cpp'): + sys.stderr.write('Ignoring %s; not a .cc or .h file\n' % filename) + else: + ProcessFileData(filename, file_extension, lines, Error, + extra_check_functions) + if carriage_return_found and os.linesep != '\r\n': + # Use 0 for linenum since outputting only one error for potentially + # several lines. + Error(filename, 0, 'whitespace/newline', 1, + 'One or more unexpected \\r (^M) found;' + 'better to use only a \\n') + + sys.stderr.write('Done processing %s\n' % filename) + + +def PrintUsage(message): + """Prints a brief usage string and exits, optionally with an error message. + + Args: + message: The optional error message. + """ + sys.stderr.write(_USAGE) + if message: + sys.exit('\nFATAL ERROR: ' + message) + else: + sys.exit(1) + + +def PrintCategories(): + """Prints a list of all the error-categories used by error messages. + + These are the categories used to filter messages via --filter. + """ + sys.stderr.write(''.join(' %s\n' % cat for cat in _ERROR_CATEGORIES)) + sys.exit(0) + + +def ParseArguments(args): + """Parses the command line arguments. + + This may set the output format and verbosity level as side-effects. + + Args: + args: The command line arguments: + + Returns: + The list of filenames to lint. + """ + try: + (opts, filenames) = getopt.getopt(args, '', ['help', 'output=', 'verbose=', + 'counting=', + 'filter=']) + except getopt.GetoptError: + PrintUsage('Invalid arguments.') + + verbosity = _VerboseLevel() + output_format = _OutputFormat() + filters = '' + counting_style = '' + + for (opt, val) in opts: + if opt == '--help': + PrintUsage(None) + elif opt == '--output': + if not val in ('emacs', 'vs7'): + PrintUsage('The only allowed output formats are emacs and vs7.') + output_format = val + elif opt == '--verbose': + verbosity = int(val) + elif opt == '--filter': + filters = val + if not filters: + PrintCategories() + elif opt == '--counting': + if val not in ('total', 'toplevel', 'detailed'): + PrintUsage('Valid counting options are total, toplevel, and detailed') + counting_style = val + + if not filenames: + PrintUsage('No files were specified.') + + _SetOutputFormat(output_format) + _SetVerboseLevel(verbosity) + _SetFilters(filters) + _SetCountingStyle(counting_style) + + return filenames + + +def main(): + filenames = ParseArguments(sys.argv[1:]) + + # Change stderr to write with replacement characters so we don't die + # if we try to print something containing non-ASCII characters. + sys.stderr = codecs.StreamReaderWriter(sys.stderr, + codecs.getreader('utf8'), + codecs.getwriter('utf8'), + 'replace') + + _cpplint_state.ResetErrorCounts() + for filename in filenames: + ProcessFile(filename, _cpplint_state.verbose_level) + _cpplint_state.PrintErrorCounts() + + sys.exit(_cpplint_state.error_count > 0) + + +if __name__ == '__main__': + main() diff --git a/src/adapters/GmmxxAdapter.h b/src/adapters/GmmxxAdapter.h index 43a3f8302..3fd990a46 100644 --- a/src/adapters/GmmxxAdapter.h +++ b/src/adapters/GmmxxAdapter.h @@ -5,10 +5,10 @@ * Author: Christian Dehnert */ -#ifndef GMMXXADAPTER_H_ -#define GMMXXADAPTER_H_ +#ifndef STORM_ADAPTERS_GMMXXADAPTER_H_ +#define STORM_ADAPTERS_GMMXXADAPTER_H_ -#include "src/storage/SquareSparseMatrix.h" +#include "src/storage/SparseMatrix.h" #include "log4cplus/logger.h" #include "log4cplus/loggingmacros.h" @@ -26,74 +26,20 @@ public: * @return A pointer to a column-major sparse matrix in gmm++ format. */ template - static gmm::csr_matrix* toGmmxxSparseMatrix(storm::storage::SquareSparseMatrix const& matrix) { - uint_fast64_t realNonZeros = matrix.getNonZeroEntryCount() + matrix.getDiagonalNonZeroEntryCount(); + static gmm::csr_matrix* toGmmxxSparseMatrix(storm::storage::SparseMatrix const& matrix) { + uint_fast64_t realNonZeros = matrix.getNonZeroEntryCount(); LOG4CPLUS_DEBUG(logger, "Converting matrix with " << realNonZeros << " non-zeros to gmm++ format."); // Prepare the resulting matrix. gmm::csr_matrix* result = new gmm::csr_matrix(matrix.rowCount, matrix.rowCount); - // Reserve enough elements for the row indications. - result->jc.reserve(matrix.rowCount + 1); - - // For the column indications and the actual values, we have to gather - // the values in a temporary array first, as we have to integrate - // the values from the diagonal. For the row indications, we can just count the number of - // inserted diagonal elements and add it to the previous value. - uint_fast64_t* tmpColumnIndicationsArray = new uint_fast64_t[realNonZeros]; - T* tmpValueArray = new T[realNonZeros]; - T zero(0); - uint_fast64_t currentPosition = 0; - uint_fast64_t insertedDiagonalElements = 0; - for (uint_fast64_t i = 0; i < matrix.rowCount; ++i) { - // Compute correct start index of row. - result->jc[i] = matrix.rowIndications[i] + insertedDiagonalElements; - - // If the current row has no non-zero which is not on the diagonal, we have to check the - // diagonal element explicitly. - if (matrix.rowIndications[i + 1] - matrix.rowIndications[i] == 0) { - if (matrix.diagonalStorage[i] != zero) { - tmpColumnIndicationsArray[currentPosition] = i; - tmpValueArray[currentPosition] = matrix.diagonalStorage[i]; - ++currentPosition; ++insertedDiagonalElements; - } - } else { - // Otherwise, we can just enumerate the non-zeros which are not on the diagonal - // and fit in the diagonal element where appropriate. - bool includedDiagonal = false; - for (uint_fast64_t j = matrix.rowIndications[i]; j < matrix.rowIndications[i + 1]; ++j) { - if (matrix.diagonalStorage[i] != zero && !includedDiagonal && matrix.columnIndications[j] > i) { - includedDiagonal = true; - tmpColumnIndicationsArray[currentPosition] = i; - tmpValueArray[currentPosition] = matrix.diagonalStorage[i]; - ++currentPosition; ++insertedDiagonalElements; - } - tmpColumnIndicationsArray[currentPosition] = matrix.columnIndications[j]; - tmpValueArray[currentPosition] = matrix.valueStorage[j]; - ++currentPosition; - } - - // If the diagonal element is non-zero and was not inserted until now (i.e. all - // off-diagonal elements in the row are before the diagonal element. - if (!includedDiagonal && matrix.diagonalStorage[i] != zero) { - tmpColumnIndicationsArray[currentPosition] = i; - tmpValueArray[currentPosition] = matrix.diagonalStorage[i]; - ++currentPosition; ++insertedDiagonalElements; - } - } - } - // Fill in sentinel element at the end. - result->jc[matrix.rowCount] = static_cast(realNonZeros); - - // Now, we can copy the temporary array to the GMMXX format. + // Copy Row Indications + std::copy(matrix.rowIndications.begin(), matrix.rowIndications.end(), std::back_inserter(result->jc)); + // Copy Columns Indications result->ir.resize(realNonZeros); - std::copy(tmpColumnIndicationsArray, tmpColumnIndicationsArray + realNonZeros, result->ir.begin()); - delete[] tmpColumnIndicationsArray; - + std::copy(matrix.columnIndications.begin(), matrix.columnIndications.end(), std::back_inserter(result->ir)); // And do the same thing with the actual values. - result->pr.resize(realNonZeros); - std::copy(tmpValueArray, tmpValueArray + realNonZeros, result->pr.begin()); - delete[] tmpValueArray; + std::copy(matrix.valueStorage.begin(), matrix.valueStorage.end(), std::back_inserter(result->pr)); LOG4CPLUS_DEBUG(logger, "Done converting matrix to gmm++ format."); @@ -105,4 +51,4 @@ public: } //namespace storm -#endif /* GMMXXADAPTER_H_ */ +#endif /* STORM_ADAPTERS_GMMXXADAPTER_H_ */ diff --git a/src/adapters/IntermediateRepresentationAdapter.h b/src/adapters/IntermediateRepresentationAdapter.h index 01e5a9a17..de35d81a5 100644 --- a/src/adapters/IntermediateRepresentationAdapter.h +++ b/src/adapters/IntermediateRepresentationAdapter.h @@ -165,7 +165,8 @@ public: storm::storage::SquareSparseMatrix* resultMatrix = new storm::storage::SquareSparseMatrix(allStates.size()); resultMatrix->initialize(totalNumberOfTransitions); - for (StateType* state : allStates) { + uint_fast64_t currentIndex = 0; + for (StateType* currentState : allStates) { // Iterate over all modules. for (uint_fast64_t i = 0; i < program.getNumberOfModules(); ++i) { storm::ir::Module const& module = program.getModule(i); @@ -176,7 +177,7 @@ public: // Check if this command is enabled in the current state. if (command.getGuard()->getValueAsBool(*currentState)) { - std::unordered_map stateToProbabilityMap; + std::map stateIndexToProbabilityMap; for (uint_fast64_t k = 0; k < command.getNumberOfUpdates(); ++k) { storm::ir::Update const& update = command.getUpdate(k); @@ -191,15 +192,29 @@ public: setValue(newState, integerVariableToIndexMap[assignedVariable.first], assignedVariable.second.getExpression()->getValueAsInt(*currentState)); } - auto probIt = stateToProbabilityMap.find(newState); - if (probIt != stateToProbabilityMap.end()) { - stateToProbabilityMap[newState] += update.getLikelihoodExpression()->getValueAsDouble(*currentState); + uint_fast64_t targetIndex = (*stateToIndexMap.find(newState)).second; + delete newState; + + auto probIt = stateIndexToProbabilityMap.find(targetIndex); + if (probIt != stateIndexToProbabilityMap.end()) { + stateIndexToProbabilityMap[targetIndex] += update.getLikelihoodExpression()->getValueAsDouble(*currentState); } else { - ++totalNumberOfTransitions; - stateToProbabilityMap[newState] = update.getLikelihoodExpression()->getValueAsDouble(*currentState); + stateIndexToProbabilityMap[targetIndex] = update.getLikelihoodExpression()->getValueAsDouble(*currentState); } + } + + // Now insert the actual values into the matrix. + //for (auto targetIndex : stateIndexToProbabilityMap) { + // resultMatrix->addNextValue(currentIndex, targetIndex.first, targetIndex.second); + //} + } + } + } + ++currentIndex; } + resultMatrix->finalize(); + // Now free all the elements we allocated. for (auto element : allStates) { delete element; diff --git a/src/exceptions/BaseException.h b/src/exceptions/BaseException.h index d3f9308ff..c1b6ee054 100644 --- a/src/exceptions/BaseException.h +++ b/src/exceptions/BaseException.h @@ -8,13 +8,11 @@ namespace storm { namespace exceptions { template -class BaseException : public std::exception -{ +class BaseException : public std::exception { public: BaseException() : exception() {} BaseException(const BaseException& cp) - : exception(cp), stream(cp.stream.str()) - { + : exception(cp), stream(cp.stream.str()) { } BaseException(const char* cstr) { @@ -24,14 +22,12 @@ class BaseException : public std::exception ~BaseException() throw() { } template - E& operator<<(const T& var) - { + E& operator<<(const T& var) { this->stream << var; return * dynamic_cast(this); } - virtual const char* what() const throw() - { + virtual const char* what() const throw() { return this->stream.str().c_str(); } diff --git a/src/formula/BoundedEventually.h b/src/formula/BoundedEventually.h index 58973456d..0a66f6b83 100644 --- a/src/formula/BoundedEventually.h +++ b/src/formula/BoundedEventually.h @@ -120,7 +120,7 @@ public: BoundedEventually* result = new BoundedEventually(); result->setBound(bound); if (child != nullptr) { - result->setRight(child->clone()); + result->setChild(child->clone()); } return result; } diff --git a/src/modelChecker/EigenDtmcPrctlModelChecker.h b/src/modelChecker/EigenDtmcPrctlModelChecker.h index 3c387bec5..e6ea58725 100644 --- a/src/modelChecker/EigenDtmcPrctlModelChecker.h +++ b/src/modelChecker/EigenDtmcPrctlModelChecker.h @@ -48,7 +48,7 @@ public: storm::storage::BitVector* rightStates = this->checkStateFormula(formula.getRight()); // Copy the matrix before we make any changes. - storm::storage::SquareSparseMatrix tmpMatrix(*this->getModel().getTransitionProbabilityMatrix()); + storm::storage::SparseMatrix tmpMatrix(*this->getModel().getTransitionProbabilityMatrix()); // Make all rows absorbing that violate both sub-formulas or satisfy the second sub-formula. tmpMatrix.makeRowsAbsorbing((~*leftStates | *rightStates) | *rightStates); @@ -148,7 +148,7 @@ public: typedef Eigen::Map MapType; // Now we can eliminate the rows and columns from the original transition probability matrix. - storm::storage::SquareSparseMatrix* submatrix = this->getModel().getTransitionProbabilityMatrix()->getSubmatrix(maybeStates); + storm::storage::SparseMatrix* submatrix = this->getModel().getTransitionProbabilityMatrix()->getSubmatrix(maybeStates); // Converting the matrix to the form needed for the equation system. That is, we go from // x = A*x + b to (I-A)x = b. submatrix->convertToEquationSystem(); diff --git a/src/modelChecker/GmmxxDtmcPrctlModelChecker.h b/src/modelChecker/GmmxxDtmcPrctlModelChecker.h index 2d4d65d49..c90efc3b0 100644 --- a/src/modelChecker/GmmxxDtmcPrctlModelChecker.h +++ b/src/modelChecker/GmmxxDtmcPrctlModelChecker.h @@ -48,7 +48,7 @@ public: storm::storage::BitVector* rightStates = this->checkStateFormula(formula.getRight()); // Copy the matrix before we make any changes. - storm::storage::SquareSparseMatrix tmpMatrix(*this->getModel().getTransitionProbabilityMatrix()); + storm::storage::SparseMatrix tmpMatrix(*this->getModel().getTransitionProbabilityMatrix()); // Make all rows absorbing that violate both sub-formulas or satisfy the second sub-formula. tmpMatrix.makeRowsAbsorbing(~(*leftStates | *rightStates) | *rightStates); @@ -130,7 +130,7 @@ public: // Only try to solve system if there are states for which the probability is unknown. if (maybeStates.getNumberOfSetBits() > 0) { // Now we can eliminate the rows and columns from the original transition probability matrix. - storm::storage::SquareSparseMatrix* submatrix = this->getModel().getTransitionProbabilityMatrix()->getSubmatrix(maybeStates); + storm::storage::SparseMatrix* submatrix = this->getModel().getTransitionProbabilityMatrix()->getSubmatrix(maybeStates); // Converting the matrix from the fixpoint notation to the form needed for the equation // system. That is, we go from x = A*x + b to (I-A)x = b. submatrix->convertToEquationSystem(); @@ -261,7 +261,7 @@ public: storm::storage::BitVector maybeStates = ~(*targetStates) & ~infinityStates; if (maybeStates.getNumberOfSetBits() > 0) { // Now we can eliminate the rows and columns from the original transition probability matrix. - storm::storage::SquareSparseMatrix* submatrix = this->getModel().getTransitionProbabilityMatrix()->getSubmatrix(maybeStates); + storm::storage::SparseMatrix* submatrix = this->getModel().getTransitionProbabilityMatrix()->getSubmatrix(maybeStates); // Converting the matrix from the fixpoint notation to the form needed for the equation // system. That is, we go from x = A*x + b to (I-A)x = b. submatrix->convertToEquationSystem(); diff --git a/src/models/AbstractModel.cpp b/src/models/AbstractModel.cpp new file mode 100644 index 000000000..70acd4e3a --- /dev/null +++ b/src/models/AbstractModel.cpp @@ -0,0 +1,25 @@ +#include "src/models/AbstractModel.h" + +#include + +/*! + * This method will output the name of the model type or "Unknown". + * If something went terribly wrong, i.e. if type does not contain any value + * that is valid for a ModelType or some value of the enum was not + * implemented here, it will output "Invalid ModelType". + * + * @param os Output stream. + * @param type Model type. + * @return Output stream os. + */ +std::ostream& storm::models::operator<<(std::ostream& os, storm::models::ModelType const type) { + switch (type) { + case storm::models::Unknown: os << "Unknown"; break; + case storm::models::DTMC: os << "DTMC"; break; + case storm::models::CTMC: os << "CTMC"; break; + case storm::models::MDP: os << "MDP"; break; + case storm::models::CTMDP: os << "CTMDP"; break; + default: os << "Invalid ModelType"; break; + } + return os; +} diff --git a/src/models/AbstractModel.h b/src/models/AbstractModel.h new file mode 100644 index 000000000..bb223b718 --- /dev/null +++ b/src/models/AbstractModel.h @@ -0,0 +1,64 @@ +#ifndef STORM_MODELS_ABSTRACTMODEL_H_ +#define STORM_MODELS_ABSTRACTMODEL_H_ + +#include + +namespace storm { +namespace models { + +/*! + * @brief Enumeration of all supported types of models. + */ +enum ModelType { + Unknown, DTMC, CTMC, MDP, CTMDP +}; + +/*! + * @brief Stream output operator for ModelType. + */ +std::ostream& operator<<(std::ostream& os, ModelType const type); + +/*! + * @brief Base class for all model classes. + * + * This is base class defines a common interface for all models to identify + * their type and obtain the special model. + */ +class AbstractModel { + + public: + /*! + * @brief Casts the model to the model type that was actually + * created. + * + * As all methods that work on generic models will use this + * AbstractModel class, this method provides a convenient way to + * cast an AbstractModel object to an object of a concrete model + * type, which can be obtained via getType(). The mapping from an + * element of the ModelType enum to the actual class must be done + * by the caller. + * + * This methods uses std::dynamic_pointer_cast internally. + * + * @return Shared pointer of new type to this object. + */ + template + std::shared_ptr as() { + return std::dynamic_pointer_cast(std::shared_ptr(this)); + } + + /*! + * @brief Return the actual type of the model. + * + * Each model must implement this method. + * + * @return Type of the model. + */ + virtual ModelType getType() = 0; + +}; + +} // namespace models +} // namespace storm + +#endif /* STORM_MODELS_ABSTRACTMODEL_H_ */ diff --git a/src/models/Ctmc.h b/src/models/Ctmc.h index 7f21528f0..263a62fe9 100644 --- a/src/models/Ctmc.h +++ b/src/models/Ctmc.h @@ -15,9 +15,11 @@ #include "AtomicPropositionsLabeling.h" #include "GraphTransitions.h" -#include "src/storage/SquareSparseMatrix.h" +#include "src/storage/SparseMatrix.h" #include "src/exceptions/InvalidArgumentException.h" +#include "src/models/AbstractModel.h" + namespace storm { namespace models { @@ -27,7 +29,7 @@ namespace models { * labeled with atomic propositions. */ template -class Ctmc { +class Ctmc : public storm::models::AbstractModel { public: //! Constructor @@ -39,10 +41,10 @@ public: * @param stateLabeling The labeling that assigns a set of atomic * propositions to each state. */ - Ctmc(std::shared_ptr> rateMatrix, + Ctmc(std::shared_ptr> rateMatrix, std::shared_ptr stateLabeling, std::shared_ptr> stateRewards = nullptr, - std::shared_ptr> transitionRewardMatrix = nullptr) + std::shared_ptr> transitionRewardMatrix = nullptr) : rateMatrix(rateMatrix), stateLabeling(stateLabeling), stateRewards(stateRewards), transitionRewardMatrix(transitionRewardMatrix), backwardTransitions(nullptr) { @@ -56,7 +58,7 @@ public: Ctmc(const Ctmc &ctmc) : rateMatrix(ctmc.rateMatrix), stateLabeling(ctmc.stateLabeling), stateRewards(ctmc.stateRewards), transitionRewardMatrix(ctmc.transitionRewardMatrix) { - if (ctmc.backardTransitions != nullptr) { + if (ctmc.backwardTransitions != nullptr) { this->backwardTransitions = new storm::models::GraphTransitions(*ctmc.backwardTransitions); } } @@ -104,7 +106,7 @@ public: * @return A pointer to the matrix representing the transition probability * function. */ - std::shared_ptr> getTransitionRateMatrix() const { + std::shared_ptr> getTransitionRateMatrix() const { return this->rateMatrix; } @@ -112,7 +114,7 @@ public: * Returns a pointer to the matrix representing the transition rewards. * @return A pointer to the matrix representing the transition rewards. */ - std::shared_ptr> getTransitionRewardMatrix() const { + std::shared_ptr> getTransitionRewardMatrix() const { return this->transitionRewardMatrix; } @@ -161,10 +163,14 @@ public: << std::endl; } + storm::models::ModelType getType() { + return CTMC; + } + private: /*! A matrix representing the transition rate function of the CTMC. */ - std::shared_ptr> rateMatrix; + std::shared_ptr> rateMatrix; /*! The labeling of the states of the CTMC. */ std::shared_ptr stateLabeling; @@ -173,7 +179,7 @@ private: std::shared_ptr> stateRewards; /*! The transition-based rewards of the CTMC. */ - std::shared_ptr> transitionRewardMatrix; + std::shared_ptr> transitionRewardMatrix; /*! * A data structure that stores the predecessors for all states. This is diff --git a/src/models/Dtmc.h b/src/models/Dtmc.h index cb75afc28..c766da8cb 100644 --- a/src/models/Dtmc.h +++ b/src/models/Dtmc.h @@ -15,9 +15,11 @@ #include "AtomicPropositionsLabeling.h" #include "GraphTransitions.h" -#include "src/storage/SquareSparseMatrix.h" +#include "src/storage/SparseMatrix.h" #include "src/exceptions/InvalidArgumentException.h" #include "src/utility/CommandLine.h" +#include "src/utility/Settings.h" +#include "src/models/AbstractModel.h" namespace storm { @@ -28,7 +30,7 @@ namespace models { * labeled with atomic propositions. */ template -class Dtmc { +class Dtmc : public storm::models::AbstractModel { public: //! Constructor @@ -40,15 +42,16 @@ public: * @param stateLabeling The labeling that assigns a set of atomic * propositions to each state. */ - Dtmc(std::shared_ptr> probabilityMatrix, + Dtmc(std::shared_ptr> probabilityMatrix, std::shared_ptr stateLabeling, std::shared_ptr> stateRewards = nullptr, - std::shared_ptr> transitionRewardMatrix = nullptr) + std::shared_ptr> transitionRewardMatrix = nullptr) : probabilityMatrix(probabilityMatrix), stateLabeling(stateLabeling), stateRewards(stateRewards), transitionRewardMatrix(transitionRewardMatrix), backwardTransitions(nullptr) { - if (!this->checkValidityProbabilityMatrix()) { - std::cerr << "Probability matrix is invalid" << std::endl; + if (!this->checkValidityOfProbabilityMatrix()) { + LOG4CPLUS_ERROR(logger, "Probability matrix is invalid."); + throw storm::exceptions::InvalidArgumentException() << "Probability matrix is invalid."; } } @@ -60,11 +63,12 @@ public: Dtmc(const Dtmc &dtmc) : probabilityMatrix(dtmc.probabilityMatrix), stateLabeling(dtmc.stateLabeling), stateRewards(dtmc.stateRewards), transitionRewardMatrix(dtmc.transitionRewardMatrix) { - if (dtmc.backardTransitions != nullptr) { + if (dtmc.backwardTransitions != nullptr) { this->backwardTransitions = new storm::models::GraphTransitions(*dtmc.backwardTransitions); } - if (!this->checkValidityProbabilityMatrix()) { - std::cerr << "Probability matrix is invalid" << std::endl; + if (!this->checkValidityOfProbabilityMatrix()) { + LOG4CPLUS_ERROR(logger, "Probability matrix is invalid."); + throw storm::exceptions::InvalidArgumentException() << "Probability matrix is invalid."; } } @@ -111,7 +115,7 @@ public: * @return A pointer to the matrix representing the transition probability * function. */ - std::shared_ptr> getTransitionProbabilityMatrix() const { + std::shared_ptr> getTransitionProbabilityMatrix() const { return this->probabilityMatrix; } @@ -119,7 +123,7 @@ public: * Returns a pointer to the matrix representing the transition rewards. * @return A pointer to the matrix representing the transition rewards. */ - std::shared_ptr> getTransitionRewardMatrix() const { + std::shared_ptr> getTransitionRewardMatrix() const { return this->transitionRewardMatrix; } @@ -192,6 +196,10 @@ public: out << std::endl; storm::utility::printSeparationLine(out); } + + storm::models::ModelType getType() { + return DTMC; + } private: @@ -200,17 +208,26 @@ private: * * Checks probability matrix if all rows sum up to one. */ - bool checkValidityProbabilityMatrix() { + bool checkValidityOfProbabilityMatrix() { + // Get the settings object to customize linear solving. + storm::settings::Settings* s = storm::settings::instance(); + double precision = s->get("precision"); + + if (this->probabilityMatrix->getRowCount() != this->probabilityMatrix->getColumnCount()) { + // not square + return false; + } + for (uint_fast64_t row = 0; row < this->probabilityMatrix->getRowCount(); row++) { T sum = this->probabilityMatrix->getRowSum(row); - if (sum == 0) continue; - if (std::abs(sum - 1) > 1e-10) return false; + if (sum == 0) return false; + if (std::abs(sum - 1) > precision) return false; } return true; } /*! A matrix representing the transition probability function of the DTMC. */ - std::shared_ptr> probabilityMatrix; + std::shared_ptr> probabilityMatrix; /*! The labeling of the states of the DTMC. */ std::shared_ptr stateLabeling; @@ -219,7 +236,7 @@ private: std::shared_ptr> stateRewards; /*! The transition-based rewards of the DTMC. */ - std::shared_ptr> transitionRewardMatrix; + std::shared_ptr> transitionRewardMatrix; /*! * A data structure that stores the predecessors for all states. This is diff --git a/src/models/GraphTransitions.h b/src/models/GraphTransitions.h index 3ccc6bed5..576d99f8c 100644 --- a/src/models/GraphTransitions.h +++ b/src/models/GraphTransitions.h @@ -8,7 +8,7 @@ #ifndef STORM_MODELS_GRAPHTRANSITIONS_H_ #define STORM_MODELS_GRAPHTRANSITIONS_H_ -#include "src/storage/SquareSparseMatrix.h" +#include "src/storage/SparseMatrix.h" #include #include @@ -39,7 +39,7 @@ public: * @param forward If set to true, this objects will store the graph structure * of the backwards transition relation. */ - GraphTransitions(std::shared_ptr> transitionMatrix, bool forward) + GraphTransitions(std::shared_ptr> transitionMatrix, bool forward) : successorList(nullptr), stateIndications(nullptr), numberOfStates(transitionMatrix->getRowCount()), numberOfNonZeroTransitions(transitionMatrix->getNonZeroEntryCount()) { if (forward) { this->initializeForward(transitionMatrix); @@ -87,18 +87,18 @@ private: * Initializes this graph transitions object using the forward transition * relation given by means of a sparse matrix. */ - void initializeForward(std::shared_ptr> transitionMatrix) { + void initializeForward(std::shared_ptr> transitionMatrix) { this->successorList = new uint_fast64_t[numberOfNonZeroTransitions]; this->stateIndications = new uint_fast64_t[numberOfStates + 1]; // First, we copy the index list from the sparse matrix as this will // stay the same. - std::copy(transitionMatrix->getRowIndicationsPointer(), transitionMatrix->getRowIndicationsPointer() + numberOfStates + 1, this->stateIndications); + std::copy(transitionMatrix->getRowIndicationsPointer().begin(), transitionMatrix->getRowIndicationsPointer().end(), this->stateIndications); // Now we can iterate over all rows of the transition matrix and record // the target state. for (uint_fast64_t i = 0, currentNonZeroElement = 0; i < numberOfStates; i++) { - for (auto rowIt = transitionMatrix->beginConstColumnNoDiagIterator(i); rowIt != transitionMatrix->endConstColumnNoDiagIterator(i); ++rowIt) { + for (auto rowIt = transitionMatrix->beginConstColumnIterator(i); rowIt != transitionMatrix->endConstColumnIterator(i); ++rowIt) { this->stateIndications[currentNonZeroElement++] = *rowIt; } } @@ -109,7 +109,7 @@ private: * relation, whose forward transition relation is given by means of a sparse * matrix. */ - void initializeBackward(std::shared_ptr> transitionMatrix) { + void initializeBackward(std::shared_ptr> transitionMatrix) { this->successorList = new uint_fast64_t[numberOfNonZeroTransitions](); this->stateIndications = new uint_fast64_t[numberOfStates + 1](); @@ -117,7 +117,7 @@ private: // NOTE: We disregard the diagonal here, as we only consider "true" // predecessors. for (uint_fast64_t i = 0; i < numberOfStates; i++) { - for (auto rowIt = transitionMatrix->beginConstColumnNoDiagIterator(i); rowIt != transitionMatrix->endConstColumnNoDiagIterator(i); ++rowIt) { + for (auto rowIt = transitionMatrix->beginConstColumnIterator(i); rowIt != transitionMatrix->endConstColumnIterator(i); ++rowIt) { this->stateIndications[*rowIt + 1]++; } } @@ -140,7 +140,7 @@ private: // Now we are ready to actually fill in the list of predecessors for // every state. Again, we start by considering all but the last row. for (uint_fast64_t i = 0; i < numberOfStates; i++) { - for (auto rowIt = transitionMatrix->beginConstColumnNoDiagIterator(i); rowIt != transitionMatrix->endConstColumnNoDiagIterator(i); ++rowIt) { + for (auto rowIt = transitionMatrix->beginConstColumnIterator(i); rowIt != transitionMatrix->endConstColumnIterator(i); ++rowIt) { this->successorList[nextIndicesList[*rowIt]++] = i; } } diff --git a/src/models/Mdp.h b/src/models/Mdp.h new file mode 100644 index 000000000..dc58782a2 --- /dev/null +++ b/src/models/Mdp.h @@ -0,0 +1,246 @@ +/* + * Mdp.h + * + * Created on: 14.01.2013 + * Author: Philipp Berger + */ + +#ifndef STORM_MODELS_MDP_H_ +#define STORM_MODELS_MDP_H_ + +#include +#include +#include +#include + +#include "AtomicPropositionsLabeling.h" +#include "GraphTransitions.h" +#include "src/storage/SparseMatrix.h" +#include "src/exceptions/InvalidArgumentException.h" +#include "src/utility/CommandLine.h" +#include "src/utility/Settings.h" +#include "src/models/AbstractModel.h" + +namespace storm { + +namespace models { + +/*! + * This class represents a Markov Decision Process (MDP) whose states are + * labeled with atomic propositions. + */ +template +class Mdp : public storm::models::AbstractModel { + +public: + //! Constructor + /*! + * Constructs a MDP object from the given transition probability matrix and + * the given labeling of the states. + * @param probabilityMatrix The transition probability relation of the + * MDP given by a matrix. + * @param stateLabeling The labeling that assigns a set of atomic + * propositions to each state. + */ + Mdp(std::shared_ptr> probabilityMatrix, + std::shared_ptr stateLabeling, + std::shared_ptr> stateRewards = nullptr, + std::shared_ptr> transitionRewardMatrix = nullptr) + : probabilityMatrix(probabilityMatrix), stateLabeling(stateLabeling), + stateRewards(stateRewards), transitionRewardMatrix(transitionRewardMatrix), + backwardTransitions(nullptr) { + if (!this->checkValidityOfProbabilityMatrix()) { + LOG4CPLUS_ERROR(logger, "Probability matrix is invalid."); + throw storm::exceptions::InvalidArgumentException() << "Probability matrix is invalid."; + } + } + + //! Copy Constructor + /*! + * Copy Constructor. Performs a deep copy of the given MDP. + * @param mdp A reference to the MDP that is to be copied. + */ + Mdp(const Mdp &mdp) : probabilityMatrix(mdp.probabilityMatrix), + stateLabeling(mdp.stateLabeling), stateRewards(mdp.stateRewards), + transitionRewardMatrix(mdp.transitionRewardMatrix) { + if (mdp.backwardTransitions != nullptr) { + this->backwardTransitions = new storm::models::GraphTransitions(*mdp.backwardTransitions); + } + if (!this->checkValidityOfProbabilityMatrix()) { + LOG4CPLUS_ERROR(logger, "Probability matrix is invalid."); + throw storm::exceptions::InvalidArgumentException() << "Probability matrix is invalid."; + } + } + + //! Destructor + /*! + * Destructor. Frees the matrix and labeling associated with this MDP. + */ + ~Mdp() { + if (this->backwardTransitions != nullptr) { + delete this->backwardTransitions; + } + } + + /*! + * Returns the state space size of the MDP. + * @return The size of the state space of the MDP. + */ + uint_fast64_t getNumberOfStates() const { + return this->probabilityMatrix->getColumnCount(); + } + + /*! + * Returns the number of (non-zero) transitions of the MDP. + * @return The number of (non-zero) transitions of the MDP. + */ + uint_fast64_t getNumberOfTransitions() const { + return this->probabilityMatrix->getNonZeroEntryCount(); + } + + /*! + * Returns a bit vector in which exactly those bits are set to true that + * correspond to a state labeled with the given atomic proposition. + * @param ap The atomic proposition for which to get the bit vector. + * @return A bit vector in which exactly those bits are set to true that + * correspond to a state labeled with the given atomic proposition. + */ + storm::storage::BitVector* getLabeledStates(std::string ap) const { + return this->stateLabeling->getAtomicProposition(ap); + } + + /*! + * Returns a pointer to the matrix representing the transition probability + * function. + * @return A pointer to the matrix representing the transition probability + * function. + */ + std::shared_ptr> getTransitionProbabilityMatrix() const { + return this->probabilityMatrix; + } + + /*! + * Returns a pointer to the matrix representing the transition rewards. + * @return A pointer to the matrix representing the transition rewards. + */ + std::shared_ptr> getTransitionRewardMatrix() const { + return this->transitionRewardMatrix; + } + + /*! + * Returns a pointer to the vector representing the state rewards. + * @return A pointer to the vector representing the state rewards. + */ + std::shared_ptr> getStateRewards() const { + return this->stateRewards; + } + + /*! + * + */ + std::set const getPropositionsForState(uint_fast64_t const &state) const { + return stateLabeling->getPropositionsForState(state); + } + + /*! + * Retrieves a reference to the backwards transition relation. + * @return A reference to the backwards transition relation. + */ + storm::models::GraphTransitions& getBackwardTransitions() { + if (this->backwardTransitions == nullptr) { + this->backwardTransitions = new storm::models::GraphTransitions(this->probabilityMatrix, false); + } + return *this->backwardTransitions; + } + + /*! + * Retrieves whether this MDP has a state reward model. + * @return True if this MDP has a state reward model. + */ + bool hasStateRewards() { + return this->stateRewards != nullptr; + } + + /*! + * Retrieves whether this MDP has a transition reward model. + * @return True if this MDP has a transition reward model. + */ + bool hasTransitionRewards() { + return this->transitionRewardMatrix != nullptr; + } + + /*! + * Retrieves whether the given atomic proposition is a valid atomic proposition in this model. + * @param atomicProposition The atomic proposition to be checked for validity. + * @return True if the given atomic proposition is valid in this model. + */ + bool hasAtomicProposition(std::string const& atomicProposition) { + return this->stateLabeling->containsAtomicProposition(atomicProposition); + } + + /*! + * Prints information about the model to the specified stream. + * @param out The stream the information is to be printed to. + */ + void printModelInformationToStream(std::ostream& out) const { + storm::utility::printSeparationLine(out); + out << std::endl; + out << "Model type: \t\tMDP" << std::endl; + out << "States: \t\t" << this->getNumberOfStates() << std::endl; + out << "Transitions: \t\t" << this->getNumberOfTransitions() << std::endl; + this->stateLabeling->printAtomicPropositionsInformationToStream(out); + out << "Size in memory: \t" + << (this->probabilityMatrix->getSizeInMemory() + + this->stateLabeling->getSizeInMemory() + + sizeof(*this))/1024 << " kbytes" << std::endl; + out << std::endl; + storm::utility::printSeparationLine(out); + } + + storm::models::ModelType getType() { + return MDP; + } + +private: + + /*! + * @brief Perform some sanity checks. + * + * Checks probability matrix if all rows sum up to one. + */ + bool checkValidityOfProbabilityMatrix() { + // Get the settings object to customize linear solving. + storm::settings::Settings* s = storm::settings::instance(); + double precision = s->get("precision"); + for (uint_fast64_t row = 0; row < this->probabilityMatrix->getRowCount(); row++) { + T sum = this->probabilityMatrix->getRowSum(row); + if (sum == 0) continue; + if (std::abs(sum - 1) > precision) return false; + } + return true; + } + + /*! A matrix representing the transition probability function of the MDP. */ + std::shared_ptr> probabilityMatrix; + + /*! The labeling of the states of the MDP. */ + std::shared_ptr stateLabeling; + + /*! The state-based rewards of the MDP. */ + std::shared_ptr> stateRewards; + + /*! The transition-based rewards of the MDP. */ + std::shared_ptr> transitionRewardMatrix; + + /*! + * A data structure that stores the predecessors for all states. This is + * needed for backwards directed searches. + */ + storm::models::GraphTransitions* backwardTransitions; +}; + +} // namespace models + +} // namespace storm + +#endif /* STORM_MODELS_MDP_H_ */ diff --git a/src/parser/AtomicPropositionLabelingParser.cpp b/src/parser/AtomicPropositionLabelingParser.cpp index 07dc2d344..c2e6774f8 100644 --- a/src/parser/AtomicPropositionLabelingParser.cpp +++ b/src/parser/AtomicPropositionLabelingParser.cpp @@ -7,20 +7,22 @@ #include "src/parser/AtomicPropositionLabelingParser.h" -#include "src/exceptions/WrongFileFormatException.h" -#include "src/exceptions/FileIoException.h" +#include +#include +#include +#include +#include -#include "src/utility/OsDetection.h" #include #include #include +#include #include #include -#include -#include -#include -#include -#include + +#include "src/exceptions/WrongFileFormatException.h" +#include "src/exceptions/FileIoException.h" +#include "src/utility/OsDetection.h" #include "log4cplus/logger.h" #include "log4cplus/loggingmacros.h" @@ -39,17 +41,16 @@ namespace parser { * @return The pointer to the created labeling object. */ AtomicPropositionLabelingParser::AtomicPropositionLabelingParser(uint_fast64_t node_count, - std::string const & filename) - : labeling(nullptr) -{ + std::string const & filename) + : labeling(nullptr) { /* - * open file + * Open file. */ MappedFile file(filename.c_str()); char* buf = file.data; /* - * first run: obtain number of propositions + * First run: obtain number of propositions. */ char separator[] = " \r\n\t"; bool foundDecl = false, foundEnd = false; @@ -57,11 +58,11 @@ AtomicPropositionLabelingParser::AtomicPropositionLabelingParser(uint_fast64_t n { size_t cnt = 0; /* - * iterate over tokens until we hit #END or end of file + * Iterate over tokens until we hit #END or end of file. */ while(buf[0] != '\0') { buf += cnt; - cnt = strcspn(buf, separator); // position of next separator + cnt = strcspn(buf, separator); // position of next separator if (cnt > 0) { /* * next token is #DECLARATION: just skip it @@ -71,31 +72,32 @@ AtomicPropositionLabelingParser::AtomicPropositionLabelingParser(uint_fast64_t n if (strncmp(buf, "#DECLARATION", cnt) == 0) { foundDecl = true; continue; - } - else if (strncmp(buf, "#END", cnt) == 0) { + } else if (strncmp(buf, "#END", cnt) == 0) { foundEnd = true; break; } proposition_count++; - } else buf++; // next char is separator, one step forward + } else { + buf++; // next char is separator, one step forward + } } - + /* * If #DECLARATION or #END were not found, the file format is wrong */ - if (! (foundDecl && foundEnd)) { + if (!(foundDecl && foundEnd)) { LOG4CPLUS_ERROR(logger, "Wrong file format in (" << filename << "). File header is corrupted."); - if (! foundDecl) LOG4CPLUS_ERROR(logger, "\tDid not find #DECLARATION token."); - if (! foundEnd) LOG4CPLUS_ERROR(logger, "\tDid not find #END token."); + if (!foundDecl) LOG4CPLUS_ERROR(logger, "\tDid not find #DECLARATION token."); + if (!foundEnd) LOG4CPLUS_ERROR(logger, "\tDid not find #END token."); throw storm::exceptions::WrongFileFormatException(); } } - + /* * create labeling object with given node and proposition count */ this->labeling = std::shared_ptr(new storm::models::AtomicPropositionsLabeling(node_count, proposition_count)); - + /* * second run: add propositions and node labels to labeling * @@ -107,11 +109,11 @@ AtomicPropositionLabelingParser::AtomicPropositionLabelingParser(uint_fast64_t n * load propositions * As we already checked the file format, we can be a bit sloppy here... */ - char proposition[128]; // buffer for proposition names + char proposition[128]; // buffer for proposition names size_t cnt = 0; do { buf += cnt; - cnt = strcspn(buf, separator); // position of next separator + cnt = strcspn(buf, separator); // position of next separator if (cnt >= sizeof(proposition)) { /* * if token is longer than our buffer, the following strncpy code might get risky... @@ -129,7 +131,9 @@ AtomicPropositionLabelingParser::AtomicPropositionLabelingParser(uint_fast64_t n strncpy(proposition, buf, cnt); proposition[cnt] = '\0'; this->labeling->addAtomicProposition(proposition); - } else cnt = 1; // next char is separator, one step forward + } else { + cnt = 1; // next char is separator, one step forward + } } while (cnt > 0); /* * Right here, the buf pointer is still pointing to our last token, @@ -137,7 +141,7 @@ AtomicPropositionLabelingParser::AtomicPropositionLabelingParser(uint_fast64_t n */ buf += 4; } - + { /* * now parse node label assignments @@ -178,5 +182,5 @@ AtomicPropositionLabelingParser::AtomicPropositionLabelingParser(uint_fast64_t n } } -} //namespace parser -} //namespace storm +} // namespace parser +} // namespace storm diff --git a/src/parser/AutoParser.cpp b/src/parser/AutoParser.cpp new file mode 100644 index 000000000..f2d161559 --- /dev/null +++ b/src/parser/AutoParser.cpp @@ -0,0 +1,70 @@ +#include "src/parser/AutoParser.h" + +#include +#include + +#include "src/exceptions/WrongFileFormatException.h" +#include "src/models/AbstractModel.h" +#include "src/parser/DtmcParser.h" +#include "src/parser/MdpParser.h" + +namespace storm { +namespace parser { + +AutoParser::AutoParser(std::string const & transitionSystemFile, std::string const & labelingFile, + std::string const & stateRewardFile, std::string const & transitionRewardFile) + : model(nullptr) { + storm::models::ModelType type = this->analyzeHint(transitionSystemFile); + + if (type == storm::models::Unknown) { + LOG4CPLUS_ERROR(logger, "Could not determine file type of " << transitionSystemFile << "."); + LOG4CPLUS_ERROR(logger, "The first line of the file should contain a format hint. Please fix your file and try again."); + throw storm::exceptions::WrongFileFormatException() << "Could not determine type of file " << transitionSystemFile; + } else { + LOG4CPLUS_INFO(logger, "Model type seems to be " << type); + } + + // Do actual parsing + switch (type) { + case storm::models::DTMC: { + DtmcParser* parser = new DtmcParser(transitionSystemFile, labelingFile, stateRewardFile, transitionRewardFile); + this->model = parser->getDtmc(); + break; + } + case storm::models::CTMC: + break; + case storm::models::MDP: { + MdpParser* parser = new MdpParser(transitionSystemFile, labelingFile, stateRewardFile, transitionRewardFile); + this->model = parser->getMdp(); + break; + } + case storm::models::CTMDP: + break; + default: ; // Unknown + } + + if (!this->model) std::cout << "model is still null" << std::endl; +} + +storm::models::ModelType AutoParser::analyzeHint(const std::string& filename) { + storm::models::ModelType hintType = storm::models::Unknown; + // Open file + MappedFile file(filename.c_str()); + char* buf = file.data; + + // parse hint + char hint[128]; + sscanf(buf, "%s\n", hint); + for (char* c = hint; *c != '\0'; c++) *c = toupper(*c); + + // check hint + if (strncmp(hint, "DTMC", sizeof(hint)) == 0) hintType = storm::models::DTMC; + else if (strncmp(hint, "CTMC", sizeof(hint)) == 0) hintType = storm::models::CTMC; + else if (strncmp(hint, "MDP", sizeof(hint)) == 0) hintType = storm::models::MDP; + else if (strncmp(hint, "CTMDP", sizeof(hint)) == 0) hintType = storm::models::CTMDP; + + return hintType; +} + +} // namespace parser +} // namespace storm diff --git a/src/parser/AutoParser.h b/src/parser/AutoParser.h new file mode 100644 index 000000000..d609f10dc --- /dev/null +++ b/src/parser/AutoParser.h @@ -0,0 +1,61 @@ +#ifndef STORM_PARSER_AUTOPARSER_H_ +#define STORM_PARSER_AUTOPARSER_H_ + +#include "src/parser/Parser.h" +#include "src/models/AbstractModel.h" + +#include +#include +#include + +namespace storm { +namespace parser { + +/*! + * @brief Checks the given files and parses the model within these files. + * + * This parser analyzes the format hitn in the first line of the transition + * file. If this is a valid format, it will use the parser for this format, + * otherwise it will throw an exception. + * + * When the files are parsed successfully, the parsed ModelType and Model + * can be obtained via getType() and getModel(). + */ +class AutoParser : Parser { + public: + AutoParser(std::string const & transitionSystemFile, std::string const & labelingFile, + std::string const & stateRewardFile = "", std::string const & transitionRewardFile = ""); + + /*! + * @brief Returns the type of model that was parsed. + */ + storm::models::ModelType getType() { + if (this->model) return this->model->getType(); + else return storm::models::Unknown; + } + + /*! + * @brief Returns the model with the given type. + */ + template + std::shared_ptr getModel() { + return this->model->as(); + } + + private: + + /*! + * @brief Open file and read file format hint. + */ + storm::models::ModelType analyzeHint(const std::string& filename); + + /*! + * @brief Pointer to a parser that has parsed the given transition system. + */ + std::shared_ptr model; +}; + +} // namespace parser +} // namespace storm + +#endif /* STORM_PARSER_AUTOPARSER_H_ */ diff --git a/src/parser/AutoTransitionParser.cpp b/src/parser/AutoTransitionParser.cpp deleted file mode 100644 index 517f72850..000000000 --- a/src/parser/AutoTransitionParser.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include "src/parser/AutoTransitionParser.h" - -#include "src/exceptions/WrongFileFormatException.h" - -#include "DeterministicSparseTransitionParser.h" -#include "NonDeterministicSparseTransitionParser.h" - -namespace storm { -namespace parser { - -AutoTransitionParser::AutoTransitionParser(const std::string& filename) - : type(Unknown) { - - TransitionType name = this->analyzeFilename(filename); - std::pair content = this->analyzeContent(filename); - TransitionType hint = content.first, transitions = content.second; - - if (hint == Unknown) { - if (name == transitions) this->type = name; - else { - LOG4CPLUS_ERROR(logger, "Could not determine file type of " << filename << ". Filename suggests " << name << " but transitions look like " << transitions); - LOG4CPLUS_ERROR(logger, "Please fix your file and try again."); - throw storm::exceptions::WrongFileFormatException() << "Could not determine type of file " << filename; - } - } else { - if ((hint == name) && (name == transitions)) this->type = name; - else if (hint == name) { - LOG4CPLUS_WARN(logger, "Transition format in file " << filename << " of type " << name << " look like " << transitions << " transitions."); - LOG4CPLUS_WARN(logger, "We will use the parser for " << name << " and hope for the best!"); - this->type = name; - } - else if (hint == transitions) { - LOG4CPLUS_WARN(logger, "File extension of " << filename << " suggests type " << name << " but the content seems to be " << hint); - LOG4CPLUS_WARN(logger, "We will use the parser for " << hint << " and hope for the best!"); - this->type = hint; - } - else if (name == transitions) { - LOG4CPLUS_WARN(logger, "File " << filename << " contains a hint that it is " << hint << " but filename and transition pattern suggests " << name); - LOG4CPLUS_WARN(logger, "We will use the parser for " << name << " and hope for the best!"); - this->type = name; - } - else { - LOG4CPLUS_WARN(logger, "File " << filename << " contains a hint that it is " << hint << " but filename suggests " << name << " and transition pattern suggests " << transitions); - LOG4CPLUS_WARN(logger, "We will stick to the hint, use the parser for " << hint << " and hope for the best!"); - this->type = hint; - } - } - - // Do actual parsing - switch (this->type) { - case DTMC: - this->parser = new DeterministicSparseTransitionParser(filename); - break; - case NDTMC: - this->parser = new NonDeterministicSparseTransitionParser(filename); - break; - default: ; // Unknown - } -} - -TransitionType AutoTransitionParser::analyzeFilename(const std::string& filename) { - TransitionType type = Unknown; - - return type; -} - -std::pair AutoTransitionParser::analyzeContent(const std::string& filename) { - - TransitionType hintType = Unknown, transType = Unknown; - // Open file - MappedFile file(filename.c_str()); - //char* buf = file.data; - - - return std::pair(hintType, transType); -} - -} //namespace parser - -} //namespace storm diff --git a/src/parser/AutoTransitionParser.h b/src/parser/AutoTransitionParser.h deleted file mode 100644 index 334173975..000000000 --- a/src/parser/AutoTransitionParser.h +++ /dev/null @@ -1,88 +0,0 @@ -#ifndef STORM_PARSER_AUTOPARSER_H_ -#define STORM_PARSER_AUTOPARSER_H_ - -#include "src/models/AtomicPropositionsLabeling.h" -#include "boost/integer/integer_mask.hpp" - -#include "src/parser/Parser.h" - -#include -#include -#include - -namespace storm { -namespace parser { - -/*! - * @brief Enumeration of all supported types of transition systems. - */ -enum TransitionType { - Unknown, DTMC, NDTMC -}; - -std::ostream& operator<<(std::ostream& os, const TransitionType type) -{ - switch (type) { - case Unknown: os << "Unknown"; break; - case DTMC: os << "DTMC"; break; - case NDTMC: os << "NDTMC"; break; - default: os << "Invalid TransitionType"; - } - return os; -} - -/*! - * @brief Checks the given file and tries to call the correct parser. - * - * This parser analyzes the filename, an optional format hint (in the first - * line of the file) and the transitions within the file. - * - * If all three (or two, if the hint is not given) are consistent, it will - * call the appropriate parser. - * If two guesses are the same but the third one contradicts, it will issue - * a warning to the user and call the (hopefully) appropriate parser. - * If all guesses differ, but a format hint is given, it will issue a - * warning to the user and use the format hint to determine the correct - * parser. - * Otherwise, it will issue an error. - */ -class AutoTransitionParser : Parser { - public: - AutoTransitionParser(const std::string& filename); - - /*! - * @brief Returns the type of transition system that was detected. - */ - TransitionType getTransitionType() { - return this->type; - } - - // TODO: is this actually safe with shared_ptr? - template - T* getParser() { - return dynamic_cast( this->parser ); - } - - ~AutoTransitionParser() { - delete this->parser; - } - private: - - TransitionType analyzeFilename(const std::string& filename); - std::pair analyzeContent(const std::string& filename); - - /*! - * @brief Type of the transition system. - */ - TransitionType type; - - /*! - * @brief Pointer to a parser that has parsed the given transition system. - */ - Parser* parser; -}; - -} // namespace parser -} // namespace storm - -#endif /* STORM_PARSER_AUTOPARSER_H_ */ diff --git a/src/parser/DeterministicSparseTransitionParser.cpp b/src/parser/DeterministicSparseTransitionParser.cpp index f3d3b9d00..32d2a7e9b 100644 --- a/src/parser/DeterministicSparseTransitionParser.cpp +++ b/src/parser/DeterministicSparseTransitionParser.cpp @@ -6,91 +6,101 @@ */ #include "src/parser/DeterministicSparseTransitionParser.h" -#include "src/exceptions/FileIoException.h" -#include "src/exceptions/WrongFileFormatException.h" -#include "boost/integer/integer_mask.hpp" -#include -#include -#include -#include -#include + #include #include #include #include #include +#include +#include +#include +#include +#include +#include + +#include "src/exceptions/FileIoException.h" +#include "src/exceptions/WrongFileFormatException.h" +#include "boost/integer/integer_mask.hpp" +#include "src/utility/Settings.h" + #include "log4cplus/logger.h" #include "log4cplus/loggingmacros.h" extern log4cplus::Logger logger; namespace storm { -namespace parser{ +namespace parser { /*! * @brief Perform first pass through the file and obtain number of * non-zero cells and maximum node id. * * This method does the first pass through the .tra file and computes - * the number of non-zero elements that are not diagonal elements, - * which correspondents to the number of transitions that are not - * self-loops. - * (Diagonal elements are treated in a special way). + * the number of non-zero elements. * It also calculates the maximum node id and stores it in maxnode. * - * @return The number of non-zero elements that are not on the diagonal + * @return The number of non-zero elements * @param buf Data to scan. Is expected to be some char array. * @param maxnode Is set to highest id of all nodes. */ -uint_fast64_t DeterministicSparseTransitionParser::firstPass(char* buf, uint_fast64_t &maxnode) { +uint_fast64_t DeterministicSparseTransitionParser::firstPass(char* buf, uint_fast64_t& maxnode) { uint_fast64_t non_zero = 0; - + /* - * check file header and extract number of transitions + * Check file header and extract number of transitions. */ + buf = strchr(buf, '\n') + 1; // skip format hint if (strncmp(buf, "STATES ", 7) != 0) { LOG4CPLUS_ERROR(logger, "Expected \"STATES\" but got \"" << std::string(buf, 0, 16) << "\"."); return 0; } - buf += 7; // skip "STATES " + buf += 7; // skip "STATES " if (strtol(buf, &buf, 10) == 0) return 0; buf = trimWhitespaces(buf); if (strncmp(buf, "TRANSITIONS ", 12) != 0) { LOG4CPLUS_ERROR(logger, "Expected \"TRANSITIONS\" but got \"" << std::string(buf, 0, 16) << "\"."); return 0; } - buf += 12; // skip "TRANSITIONS " + buf += 12; // skip "TRANSITIONS " if ((non_zero = strtol(buf, &buf, 10)) == 0) return 0; - + /* - * check all transitions for non-zero diagonal entrys + * Check all transitions for non-zero diagonal entrys. */ - uint_fast64_t row, col; + uint_fast64_t row, lastrow = 0, col; double val; maxnode = 0; - char* tmp; while (buf[0] != '\0') { /* - * read row and column + * Read row and column. */ row = checked_strtol(buf, &buf); col = checked_strtol(buf, &buf); /* - * check if one is larger than the current maximum id + * Check if one is larger than the current maximum id. */ if (row > maxnode) maxnode = row; if (col > maxnode) maxnode = col; /* - * read value. if value is 0.0, either strtod could not read a number or we encountered a probability of zero. - * if row == col, we have a diagonal element which is treated separately and this non_zero must be decreased. + * Check if a node was skipped, i.e. if a node has no outgoing + * transitions. If so, increase non_zero, as the second pass will + * either throw an exception (in this case, it doesn't matter + * anyway) or add a self-loop (in this case, we'll need the + * additional transition). + */ + if (lastrow < row-1) non_zero += row - lastrow - 1; + lastrow = row; + /* + * Read probability of this transition. + * Check, if the value is a probability, i.e. if it is between 0 and 1. */ - val = strtod(buf, &tmp); - if (val == 0.0) { - LOG4CPLUS_ERROR(logger, "Expected a positive probability but got \"" << std::string(buf, 0, 16) << "\"."); + val = checked_strtod(buf, &buf); + if ((val < 0.0) || (val > 1.0)) { + LOG4CPLUS_ERROR(logger, "Expected a positive probability but got \"" << val << "\"."); return 0; } - if (row == col) non_zero--; - buf = trimWhitespaces(tmp); + buf = trimWhitespaces(buf); } return non_zero; @@ -107,86 +117,102 @@ uint_fast64_t DeterministicSparseTransitionParser::firstPass(char* buf, uint_fas */ DeterministicSparseTransitionParser::DeterministicSparseTransitionParser(std::string const &filename) - : matrix(nullptr) -{ + : matrix(nullptr) { /* - * enforce locale where decimal point is '.' - */ - setlocale( LC_NUMERIC, "C" ); - + * Enforce locale where decimal point is '.'. + */ + setlocale(LC_NUMERIC, "C"); + /* - * open file + * Open file. */ MappedFile file(filename.c_str()); char* buf = file.data; - + /* - * perform first pass, i.e. count entries that are not zero and not on the diagonal + * Perform first pass, i.e. count entries that are not zero. */ uint_fast64_t maxnode; uint_fast64_t non_zero = this->firstPass(file.data, maxnode); /* - * if first pass returned zero, the file format was wrong + * If first pass returned zero, the file format was wrong. */ - if (non_zero == 0) - { + if (non_zero == 0) { LOG4CPLUS_ERROR(logger, "Error while parsing " << filename << ": erroneous file format."); throw storm::exceptions::WrongFileFormatException(); } - + /* - * perform second pass + * Perform second pass- * - * from here on, we already know that the file header is correct + * From here on, we already know that the file header is correct. */ /* - * read file header, extract number of states + * Read file header, extract number of states. */ - buf += 7; // skip "STATES " + buf = strchr(buf, '\n') + 1; // skip format hint + buf += 7; // skip "STATES " checked_strtol(buf, &buf); buf = trimWhitespaces(buf); - buf += 12; // skip "TRANSITIONS " + buf += 12; // skip "TRANSITIONS " checked_strtol(buf, &buf); - + /* - * Creating matrix - * Memory for diagonal elements is automatically allocated, hence only the number of non-diagonal - * non-zero elements has to be specified (which is non_zero, computed by make_first_pass) + * Creating matrix here. + * The number of non-zero elements is computed by firstPass(). */ LOG4CPLUS_INFO(logger, "Attempting to create matrix of size " << (maxnode+1) << " x " << (maxnode+1) << "."); - this->matrix = std::shared_ptr>(new storm::storage::SquareSparseMatrix(maxnode + 1)); - if (this->matrix == NULL) - { + this->matrix = std::shared_ptr>(new storm::storage::SparseMatrix(maxnode + 1)); + if (this->matrix == NULL) { LOG4CPLUS_ERROR(logger, "Could not create matrix of size " << (maxnode+1) << " x " << (maxnode+1) << "."); throw std::bad_alloc(); } this->matrix->initialize(non_zero); - uint_fast64_t row, col; + uint_fast64_t row, lastrow = 0, col; double val; + bool fixDeadlocks = storm::settings::instance()->isSet("fix-deadlocks"); + bool hadDeadlocks = false; /* - * read all transitions from file + * Read all transitions from file. Note that we assume, that the + * transitions are listed in canonical order, otherwise this will not + * work, i.e. the values in the matrix will be at wrong places. */ - while (buf[0] != '\0') - { + while (buf[0] != '\0') { /* - * read row, col and value. + * Read row, col and value. */ row = checked_strtol(buf, &buf); col = checked_strtol(buf, &buf); - val = strtod(buf, &buf); - - this->matrix->addNextValue(row,col,val); + val = checked_strtod(buf, &buf); + + /* + * Check if we skipped a node, i.e. if a node does not have any + * outgoing transitions. + */ + for (uint_fast64_t node = lastrow + 1; node < row; node++) { + hadDeadlocks = true; + if (fixDeadlocks) { + this->matrix->addNextValue(node, node, 1); + LOG4CPLUS_WARN(logger, "Warning while parsing " << filename << ": node " << node << " has no outgoing transitions. A self-loop was inserted."); + } else { + LOG4CPLUS_ERROR(logger, "Error while parsing " << filename << ": node " << node << " has no outgoing transitions."); + } + } + lastrow = row; + + this->matrix->addNextValue(row, col, val); buf = trimWhitespaces(buf); } - + if (!fixDeadlocks && hadDeadlocks) throw storm::exceptions::WrongFileFormatException() << "Some of the nodes had deadlocks. You can use --fix-deadlocks to insert self-loops on the fly."; + /* - * clean up + * Finalize Matrix. */ this->matrix->finalize(); } -} //namespace parser -} //namespace storm +} // namespace parser +} // namespace storm diff --git a/src/parser/DeterministicSparseTransitionParser.h b/src/parser/DeterministicSparseTransitionParser.h index a5b8560de..1d699d0c8 100644 --- a/src/parser/DeterministicSparseTransitionParser.h +++ b/src/parser/DeterministicSparseTransitionParser.h @@ -1,7 +1,7 @@ #ifndef STORM_PARSER_TRAPARSER_H_ #define STORM_PARSER_TRAPARSER_H_ -#include "src/storage/SquareSparseMatrix.h" +#include "src/storage/SparseMatrix.h" #include "src/parser/Parser.h" #include "src/utility/OsDetection.h" @@ -19,12 +19,12 @@ class DeterministicSparseTransitionParser : public Parser { public: DeterministicSparseTransitionParser(std::string const &filename); - std::shared_ptr> getMatrix() { + std::shared_ptr> getMatrix() { return this->matrix; } private: - std::shared_ptr> matrix; + std::shared_ptr> matrix; uint_fast64_t firstPass(char* buf, uint_fast64_t &maxnode); diff --git a/src/parser/DtmcParser.cpp b/src/parser/DtmcParser.cpp index 4f4ed2645..b47b982e8 100644 --- a/src/parser/DtmcParser.cpp +++ b/src/parser/DtmcParser.cpp @@ -5,10 +5,14 @@ * Author: thomas */ -#include "DtmcParser.h" -#include "DeterministicSparseTransitionParser.h" -#include "AtomicPropositionLabelingParser.h" -#include "SparseStateRewardParser.h" +#include "src/parser/DtmcParser.h" + +#include +#include + +#include "src/parser/DeterministicSparseTransitionParser.h" +#include "src/parser/AtomicPropositionLabelingParser.h" +#include "src/parser/SparseStateRewardParser.h" namespace storm { namespace parser { @@ -29,7 +33,7 @@ DtmcParser::DtmcParser(std::string const & transitionSystemFile, std::string con uint_fast64_t stateCount = tp.getMatrix()->getRowCount(); std::shared_ptr> stateRewards = nullptr; - std::shared_ptr> transitionRewards = nullptr; + std::shared_ptr> transitionRewards = nullptr; storm::parser::AtomicPropositionLabelingParser lp(stateCount, labelingFile); if (stateRewardFile != "") { diff --git a/src/parser/DtmcParser.h b/src/parser/DtmcParser.h index b4ab4a3b5..b1f746bb5 100644 --- a/src/parser/DtmcParser.h +++ b/src/parser/DtmcParser.h @@ -5,11 +5,11 @@ * Author: thomas */ -#ifndef DTMCPARSER_H_ -#define DTMCPARSER_H_ +#ifndef STORM_PARSER_DTMCPARSER_H_ +#define STORM_PARSER_DTMCPARSER_H_ -#include "Parser.h" -#include "models/Dtmc.h" +#include "src/parser/Parser.h" +#include "src/models/Dtmc.h" namespace storm { namespace parser { @@ -37,4 +37,4 @@ private: } /* namespace parser */ } /* namespace storm */ -#endif /* DTMCPARSER_H_ */ +#endif /* STORM_PARSER_DTMCPARSER_H_ */ diff --git a/src/parser/MdpParser.cpp b/src/parser/MdpParser.cpp new file mode 100644 index 000000000..f73c0b515 --- /dev/null +++ b/src/parser/MdpParser.cpp @@ -0,0 +1,53 @@ +/* + * MdpParser.cpp + * + * Created on: 14.01.2013 + * Author: Philipp Berger + */ + +#include "src/parser/MdpParser.h" + +#include +#include + +#include "src/parser/NonDeterministicSparseTransitionParser.h" +#include "src/parser/AtomicPropositionLabelingParser.h" +#include "src/parser/SparseStateRewardParser.h" + +namespace storm { +namespace parser { + +/*! + * Parses a transition file and a labeling file and produces a MDP out of them; a pointer to the mdp + * is saved in the field "mdp" + * Note that the labeling file may have at most as many nodes as the transition file! + * + * @param transitionSystemFile String containing the location of the transition file (....tra) + * @param labelingFile String containing the location of the labeling file (....lab) + * @param stateRewardFile String containing the location of the state reward file (...srew) + * @param transitionRewardFile String containing the location of the transition reward file (...trew) + */ +MdpParser::MdpParser(std::string const & transitionSystemFile, std::string const & labelingFile, + std::string const & stateRewardFile, std::string const & transitionRewardFile) { + storm::parser::NonDeterministicSparseTransitionParser tp(transitionSystemFile); + uint_fast64_t stateCount = tp.getMatrix()->getRowCount(); + + std::shared_ptr> stateRewards = nullptr; + std::shared_ptr> transitionRewards = nullptr; + + storm::parser::AtomicPropositionLabelingParser lp(stateCount, labelingFile); + if (stateRewardFile != "") { + storm::parser::SparseStateRewardParser srp(stateCount, stateRewardFile); + stateRewards = srp.getStateRewards(); + } + if (transitionRewardFile != "") { + storm::parser::NonDeterministicSparseTransitionParser trp(transitionRewardFile); + transitionRewards = trp.getMatrix(); + } + + mdp = std::shared_ptr>(new storm::models::Mdp(tp.getMatrix(), lp.getLabeling(), stateRewards, transitionRewards)); +} + +} /* namespace parser */ + +} /* namespace storm */ diff --git a/src/parser/MdpParser.h b/src/parser/MdpParser.h new file mode 100644 index 000000000..e64356dc9 --- /dev/null +++ b/src/parser/MdpParser.h @@ -0,0 +1,40 @@ +/* + * MdpParser.h + * + * Created on: 14.01.2013 + * Author: thomas + */ + +#ifndef STORM_PARSER_MDPPARSER_H_ +#define STORM_PARSER_MDPPARSER_H_ + +#include "src/parser/Parser.h" +#include "src/models/Mdp.h" + +namespace storm { +namespace parser { + +/*! + * @brief Load label and transition file and return initialized mdp object + * + * @Note This class creates a new Mdp object that can + * be accessed via getMdp(). However, it will not delete this object! + * + * @Note The labeling representation in the file may use at most as much nodes as are specified in the mdp. + */ +class MdpParser: public storm::parser::Parser { +public: + MdpParser(std::string const & transitionSystemFile, std::string const & labelingFile, + std::string const & stateRewardFile = "", std::string const & transitionRewardFile = ""); + + std::shared_ptr> getMdp() { + return this->mdp; + } + +private: + std::shared_ptr> mdp; +}; + +} /* namespace parser */ +} /* namespace storm */ +#endif /* STORM_PARSER_MDPPARSER_H_ */ diff --git a/src/parser/NonDeterministicSparseTransitionParser.cpp b/src/parser/NonDeterministicSparseTransitionParser.cpp index 822235c11..394006f99 100644 --- a/src/parser/NonDeterministicSparseTransitionParser.cpp +++ b/src/parser/NonDeterministicSparseTransitionParser.cpp @@ -6,106 +6,141 @@ */ #include "src/parser/NonDeterministicSparseTransitionParser.h" -#include "src/exceptions/FileIoException.h" -#include "src/exceptions/WrongFileFormatException.h" -#include "boost/integer/integer_mask.hpp" -#include -#include -#include -#include -#include + #include #include #include #include #include +#include +#include +#include +#include +#include +#include +#include + +#include "src/utility/Settings.h" +#include "src/exceptions/FileIoException.h" +#include "src/exceptions/WrongFileFormatException.h" +#include "boost/integer/integer_mask.hpp" #include "log4cplus/logger.h" #include "log4cplus/loggingmacros.h" extern log4cplus::Logger logger; namespace storm { -namespace parser{ +namespace parser { /*! - * @brief Perform first pass through the file and obtain number of - * non-zero cells and maximum node id. + * @brief Perform first pass through the file and obtain overall number of + * choices, number of non-zero cells and maximum node id. * - * This method does the first pass through the .tra file and computes - * the number of non-zero elements that are not diagonal elements, - * which correspondents to the number of transitions that are not - * self-loops. - * (Diagonal elements are treated in a special way). - * It also calculates the maximum node id and stores it in maxnode. - * It also stores the maximum number of nondeterministic choices for a - * single single node in maxchoices. + * This method does the first pass through the transition file. + * + * It computes the overall number of nondeterministic choices, i.e. the + * number of rows in the matrix that should be created. + * It also calculates the overall number of non-zero cells, i.e. the number + * of elements the matrix has to hold, and the maximum node id, i.e. the + * number of columns of the matrix. * - * @return The number of non-zero elements that are not on the diagonal * @param buf Data to scan. Is expected to be some char array. + * @param choices Overall number of choices. * @param maxnode Is set to highest id of all nodes. + * @return The number of non-zero elements. */ -std::unique_ptr> NonDeterministicSparseTransitionParser::firstPass(char* buf, uint_fast64_t &maxnode, uint_fast64_t &maxchoice) { - std::unique_ptr> non_zero = std::unique_ptr>(new std::vector()); - +uint_fast64_t NonDeterministicSparseTransitionParser::firstPass(char* buf, uint_fast64_t& choices, uint_fast64_t& maxnode) { /* - * check file header and extract number of transitions + * Check file header and extract number of transitions. */ + buf = strchr(buf, '\n') + 1; // skip format hint if (strncmp(buf, "STATES ", 7) != 0) { LOG4CPLUS_ERROR(logger, "Expected \"STATES\" but got \"" << std::string(buf, 0, 16) << "\"."); - return nullptr; + return 0; } - buf += 7; // skip "STATES " + buf += 7; // skip "STATES " if (strtol(buf, &buf, 10) == 0) return 0; buf = trimWhitespaces(buf); if (strncmp(buf, "TRANSITIONS ", 12) != 0) { LOG4CPLUS_ERROR(logger, "Expected \"TRANSITIONS\" but got \"" << std::string(buf, 0, 16) << "\"."); return 0; } - buf += 12; // skip "TRANSITIONS " - strtol(buf, &buf, 10); - + buf += 12; // skip "TRANSITIONS " /* - * check all transitions for non-zero diagonal entrys + * Parse number of transitions. + * We will not actually use this value, but we will compare it to the + * number of transitions we count and issue a warning if this parsed + * vlaue is wrong. */ - uint_fast64_t row, col, ndchoice; + uint_fast64_t parsed_nonzero = strtol(buf, &buf, 10); + + /* + * Read all transitions. + */ + uint_fast64_t source, target; + uint_fast64_t lastsource = 0; + uint_fast64_t nonzero = 0; double val; + choices = 0; maxnode = 0; - maxchoice = 0; - char* tmp; while (buf[0] != '\0') { /* - * read row and column - */ - row = checked_strtol(buf, &buf); - ndchoice = checked_strtol(buf, &buf); - col = checked_strtol(buf, &buf); - /* - * check if one is larger than the current maximum id - */ - if (row > maxnode) maxnode = row; - if (col > maxnode) maxnode = col; - /* - * check if nondeterministic choice is larger than current maximum + * Read source node. + * Check if current source node is larger than current maximum node id. + * Increase number of choices. + * Check if we have skipped any source node, i.e. if any node has no + * outgoing transitions. If so, increase nonzero (and + * parsed_nonzero). */ - if (ndchoice > maxchoice) - { - maxchoice = ndchoice; - while (non_zero->size() < maxchoice) non_zero->push_back(0); + source = checked_strtol(buf, &buf); + if (source > maxnode) maxnode = source; + choices++; + if (source > lastsource + 1) { + nonzero += source - lastsource - 1; + parsed_nonzero += source - lastsource - 1; } + lastsource = source; + buf = trimWhitespaces(buf); // Skip to name of choice + buf += strcspn(buf, " \t\n\r"); // Skip name of choice. + /* - * read value. if value is 0.0, either strtod could not read a number or we encountered a probability of zero. - * if row == col, we have a diagonal element which is treated separately and this non_zero must be decreased. + * Read all targets for this choice. */ - val = strtod(buf, &tmp); - if (val == 0.0) { - LOG4CPLUS_ERROR(logger, "Expected a positive probability but got \"" << std::string(buf, 0, 16) << "\"."); - return 0; + buf = trimWhitespaces(buf); + while (buf[0] == '*') { + buf++; + /* + * Read target node and transition value. + * Check if current target node is larger than current maximum node id. + * Check if the transition value is a valid probability. + */ + target = checked_strtol(buf, &buf); + if (target > maxnode) maxnode = target; + val = checked_strtod(buf, &buf); + if ((val < 0.0) || (val > 1.0)) { + LOG4CPLUS_ERROR(logger, "Expected a positive probability but got \"" << std::string(buf, 0, 16) << "\"."); + return 0; + } + + /* + * Increase number of non-zero values. + */ + nonzero++; + + /* + * Proceed to beginning of next line. + */ + buf = trimWhitespaces(buf); } - if (row != col) (*non_zero)[ndchoice-1]++; - buf = trimWhitespaces(tmp); } - return non_zero; + /* + * Check if the number of transitions given in the file is correct. + */ + if (nonzero != parsed_nonzero) { + LOG4CPLUS_WARN(logger, "File states to have " << parsed_nonzero << " transitions, but I counted " << nonzero << ". Maybe want to fix your file?"); + } + return nonzero; } @@ -119,89 +154,144 @@ std::unique_ptr> NonDeterministicSparseTransitionPars */ NonDeterministicSparseTransitionParser::NonDeterministicSparseTransitionParser(std::string const &filename) - : matrix(nullptr) -{ + : matrix(nullptr) { /* - * enforce locale where decimal point is '.' - */ - setlocale( LC_NUMERIC, "C" ); - + * Enforce locale where decimal point is '.'. + */ + setlocale(LC_NUMERIC, "C"); + /* - * open file + * Open file. */ MappedFile file(filename.c_str()); char* buf = file.data; - + /* - * perform first pass, i.e. count entries that are not zero and not on the diagonal + * Perform first pass, i.e. obtain number of columns, rows and non-zero elements. */ - uint_fast64_t maxnode, maxchoices; - std::unique_ptr> non_zero = this->firstPass(file.data, maxnode, maxchoices); - + uint_fast64_t maxnode, choices; + uint_fast64_t nonzero = this->firstPass(file.data, choices, maxnode); + /* - * if first pass returned zero, the file format was wrong + * If first pass returned zero, the file format was wrong. */ - if (non_zero == nullptr) - { + if (nonzero == 0) { LOG4CPLUS_ERROR(logger, "Error while parsing " << filename << ": erroneous file format."); throw storm::exceptions::WrongFileFormatException(); } - + /* - * perform second pass + * Perform second pass. * - * from here on, we already know that the file header is correct + * From here on, we already know that the file header is correct. */ /* - * read file header, extract number of states + * Read file header, ignore values within. */ - buf += 7; // skip "STATES " + buf = strchr(buf, '\n') + 1; // skip format hint + buf += 7; // skip "STATES " checked_strtol(buf, &buf); buf = trimWhitespaces(buf); - buf += 12; // skip "TRANSITIONS " + buf += 12; // skip "TRANSITIONS " checked_strtol(buf, &buf); - + /* - * Creating matrix - * Memory for diagonal elements is automatically allocated, hence only the number of non-diagonal - * non-zero elements has to be specified (which is non_zero, computed by make_first_pass) + * Create and initialize matrix. + * The matrix should have as many columns as we have nodes and as many rows as we have choices. + * Those two values, as well as the number of nonzero elements, was been calculated in the first run. */ - LOG4CPLUS_INFO(logger, "Attempting to create matrix of size " << (maxnode+1) << " x " << (maxnode+1) << "."); - this->matrix = std::shared_ptr>(new storm::storage::SquareSparseMatrix(maxnode + 1)); - if (this->matrix == NULL) - { - LOG4CPLUS_ERROR(logger, "Could not create matrix of size " << (maxnode+1) << " x " << (maxnode+1) << "."); + LOG4CPLUS_INFO(logger, "Attempting to create matrix of size " << choices << " x " << (maxnode+1) << " with " << nonzero << " entries."); + this->matrix = std::shared_ptr>(new storm::storage::SparseMatrix(choices, maxnode + 1)); + if (this->matrix == nullptr) { + LOG4CPLUS_ERROR(logger, "Could not create matrix of size " << choices << " x " << (maxnode+1) << "."); throw std::bad_alloc(); } - // TODO: put stuff in matrix / matrices. - //this->matrix->initialize(*non_zero); + this->matrix->initialize(nonzero); - uint_fast64_t row, col, ndchoice; + /* + * Create row mapping. + */ + this->rowMapping = std::shared_ptr(new RowMapping()); + + /* + * Parse file content. + */ + uint_fast64_t source, target, lastsource = 0; + uint_fast64_t curRow = 0; + std::string choice; double val; + bool fixDeadlocks = storm::settings::instance()->isSet("fix-deadlocks"); + bool hadDeadlocks = false; /* - * read all transitions from file + * Read all transitions from file. */ - while (buf[0] != '\0') - { + while (buf[0] != '\0') { /* - * read row, col and value. + * Read source node and choice name. + */ + source = checked_strtol(buf, &buf); + buf = trimWhitespaces(buf); // Skip to name of choice + choice = std::string(buf, strcspn(buf, " \t\n\r")); + + /* + * Check if we have skipped any source node, i.e. if any node has no + * outgoing transitions. If so, insert a self-loop. + * Also add self-loops to rowMapping. + */ + for (uint_fast64_t node = lastsource + 1; node < source; node++) { + hadDeadlocks = true; + if (fixDeadlocks) { + this->rowMapping->insert(RowMapping::value_type(curRow, std::pair(node, ""))); + this->matrix->addNextValue(curRow, node, 1); + curRow++; + LOG4CPLUS_WARN(logger, "Warning while parsing " << filename << ": node " << node << " has no outgoing transitions. A self-loop was inserted."); + } else { + LOG4CPLUS_ERROR(logger, "Error while parsing " << filename << ": node " << node << " has no outgoing transitions."); + } + } + lastsource = source; + + /* + * Add this source-choice pair to rowMapping. + */ + this->rowMapping->insert(RowMapping::value_type(curRow, std::pair(source, choice))); + + /* + * Skip name of choice. + */ + buf += strcspn(buf, " \t\n\r"); + + /* + * Read all targets for this choice. */ - row = checked_strtol(buf, &buf); - ndchoice = checked_strtol(buf, &buf); - col = checked_strtol(buf, &buf); - val = strtod(buf, &buf); - - //this->matrix->addNextValue(row,col,val); buf = trimWhitespaces(buf); + while (buf[0] == '*') { + buf++; + /* + * Read target node and transition value. + * Put it into the matrix. + */ + target = checked_strtol(buf, &buf); + val = checked_strtod(buf, &buf); + this->matrix->addNextValue(curRow, target, val); + + /* + * Proceed to beginning of next line in file and next row in matrix. + */ + buf = trimWhitespaces(buf); + } + curRow++; } - + + if (!fixDeadlocks && hadDeadlocks) throw storm::exceptions::WrongFileFormatException() << "Some of the nodes had deadlocks. You can use --fix-deadlocks to insert self-loops on the fly."; + /* - * clean up + * Finalize matrix. */ - //this->matrix->finalize(); + this->matrix->finalize(); } -} //namespace parser -} //namespace storm +} // namespace parser +} // namespace storm diff --git a/src/parser/NonDeterministicSparseTransitionParser.h b/src/parser/NonDeterministicSparseTransitionParser.h index bc144b160..5b487d5a2 100644 --- a/src/parser/NonDeterministicSparseTransitionParser.h +++ b/src/parser/NonDeterministicSparseTransitionParser.h @@ -1,11 +1,13 @@ #ifndef STORM_PARSER_NONDETTRAPARSER_H_ #define STORM_PARSER_NONDETTRAPARSER_H_ -#include "src/storage/SquareSparseMatrix.h" +#include "src/storage/SparseMatrix.h" #include "src/parser/Parser.h" #include "src/utility/OsDetection.h" +#include +#include #include #include @@ -20,14 +22,20 @@ class NonDeterministicSparseTransitionParser : public Parser { public: NonDeterministicSparseTransitionParser(std::string const &filename); - std::shared_ptr> getMatrix() { + inline std::shared_ptr> getMatrix() const { return this->matrix; } + + typedef boost::bimap> RowMapping; + inline std::shared_ptr getRowMapping() const { + return this->rowMapping; + } private: - std::shared_ptr> matrix; + std::shared_ptr> matrix; + std::shared_ptr rowMapping; - std::unique_ptr> firstPass(char* buf, uint_fast64_t &maxnode, uint_fast64_t &maxchoice); + uint_fast64_t firstPass(char* buf, uint_fast64_t& choices, uint_fast64_t& maxnode); }; diff --git a/src/parser/Parser.cpp b/src/parser/Parser.cpp index 0c53f3ffa..d84ed6ab6 100644 --- a/src/parser/Parser.cpp +++ b/src/parser/Parser.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "src/exceptions/FileIoException.h" #include "src/exceptions/WrongFileFormatException.h" @@ -28,6 +29,24 @@ uint_fast64_t storm::parser::Parser::checked_strtol(const char* str, char** end) return res; } +/*! + * Calls strtod() internally and checks if the new pointer is different + * from the original one, i.e. if str != *end. If they are the same, a + * storm::exceptions::WrongFileFormatException will be thrown. + * @param str String to parse + * @param end New pointer will be written there + * @return Result of strtod() + */ +double storm::parser::Parser::checked_strtod(const char* str, char** end) { + double res = strtod(str, end); + if (str == *end) { + LOG4CPLUS_ERROR(logger, "Error while parsing floating point. Next input token is not a number."); + LOG4CPLUS_ERROR(logger, "\tUpcoming input is: \"" << std::string(str, 0, 16) << "\""); + throw storm::exceptions::WrongFileFormatException("Error while parsing floating point. Next input token is not a number."); + } + return res; +} + /*! * Skips spaces, tabs, newlines and carriage returns. Returns pointer * to first char that is not a whitespace. @@ -35,13 +54,10 @@ uint_fast64_t storm::parser::Parser::checked_strtol(const char* str, char** end) * @return pointer to first non-whitespace character */ char* storm::parser::Parser::trimWhitespaces(char* buf) { - /*TODO: Maybe use memcpy to copy all the stuff from the first non-whitespace char - * to the position of the buffer, so we don't have to keep track of 2 pointers. - */ while ((*buf == ' ') || (*buf == '\t') || (*buf == '\n') || (*buf == '\r')) buf++; return buf; } - + /*! * Will stat the given file, open it and map it to memory. * If anything of this fails, an appropriate exception is raised @@ -68,9 +84,9 @@ storm::parser::MappedFile::MappedFile(const char* filename) { LOG4CPLUS_ERROR(logger, "Error in open(" << filename << ")."); throw exceptions::FileIoException("storm::parser::MappedFile Error in open()"); } - - this->data = (char*) mmap(NULL, this->st.st_size, PROT_READ, MAP_PRIVATE, this->file, 0); - if (this->data == (char*)-1) { + + this->data = reinterpret_cast(mmap(NULL, this->st.st_size, PROT_READ, MAP_PRIVATE, this->file, 0)); + if (this->data == reinterpret_cast(-1)) { close(this->file); LOG4CPLUS_ERROR(logger, "Error in mmap(" << filename << ")."); throw exceptions::FileIoException("storm::parser::MappedFile Error in mmap()"); @@ -85,20 +101,20 @@ storm::parser::MappedFile::MappedFile(const char* filename) { LOG4CPLUS_ERROR(logger, "Error in _stat(" << filename << ")."); throw exceptions::FileIoException("storm::parser::MappedFile Error in stat()"); } - + this->file = CreateFileA(filename, GENERIC_READ, 0, NULL, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL); if (this->file == INVALID_HANDLE_VALUE) { LOG4CPLUS_ERROR(logger, "Error in CreateFileA(" << filename << ")."); throw exceptions::FileIoException("storm::parser::MappedFile Error in CreateFileA()"); } - + this->mapping = CreateFileMappingA(this->file, NULL, PAGE_READONLY, (DWORD)(st.st_size >> 32), (DWORD)st.st_size, NULL); if (this->mapping == NULL) { CloseHandle(this->file); LOG4CPLUS_ERROR(logger, "Error in CreateFileMappingA(" << filename << ")."); throw exceptions::FileIoException("storm::parser::MappedFile Error in CreateFileMappingA()"); } - + this->data = static_cast(MapViewOfFile(this->mapping, FILE_MAP_READ, 0, 0, this->st.st_size)); if (this->data == NULL) { CloseHandle(this->mapping); @@ -109,7 +125,7 @@ storm::parser::MappedFile::MappedFile(const char* filename) { this->dataend = this->data + this->st.st_size; #endif } - + /*! * Will unmap the data and close the file. */ diff --git a/src/parser/Parser.h b/src/parser/Parser.h index 9aa8eef79..ef6c38c84 100644 --- a/src/parser/Parser.h +++ b/src/parser/Parser.h @@ -102,6 +102,11 @@ namespace parser { * @brief Parses integer and checks, if something has been parsed. */ uint_fast64_t checked_strtol(const char* str, char** end); + + /*! + * @brief Parses floating point and checks, if something has been parsed. + */ + double checked_strtod(const char* str, char** end); /*! * @brief Skips common whitespaces in a string. diff --git a/src/parser/SparseStateRewardParser.cpp b/src/parser/SparseStateRewardParser.cpp index bac83a2c0..b79b64292 100644 --- a/src/parser/SparseStateRewardParser.cpp +++ b/src/parser/SparseStateRewardParser.cpp @@ -6,21 +6,23 @@ */ #include "src/parser/SparseStateRewardParser.h" -#include "src/exceptions/WrongFileFormatException.h" -#include "src/exceptions/FileIoException.h" -#include "src/utility/OsDetection.h" -#include -#include -#include -#include -#include #include #include #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include "src/exceptions/WrongFileFormatException.h" +#include "src/exceptions/FileIoException.h" +#include "src/utility/OsDetection.h" #include "log4cplus/logger.h" #include "log4cplus/loggingmacros.h" extern log4cplus::Logger logger; @@ -44,7 +46,7 @@ SparseStateRewardParser::SparseStateRewardParser(uint_fast64_t stateCount, std:: // Create state reward vector with given state count. this->stateRewards = std::shared_ptr>(new std::vector(stateCount)); - + { // Now parse state reward assignments. uint_fast64_t state; @@ -54,9 +56,9 @@ SparseStateRewardParser::SparseStateRewardParser(uint_fast64_t stateCount, std:: while (buf[0] != '\0') { // Parse state number and reward value. state = checked_strtol(buf, &buf); - reward = strtod(buf, &buf); + reward = checked_strtod(buf, &buf); if (reward < 0.0) { - LOG4CPLUS_ERROR(logger, "Expected positive probability but got \"" << std::string(buf, 0, 16) << "\"."); + LOG4CPLUS_ERROR(logger, "Expected positive reward value but got \"" << reward << "\"."); throw storm::exceptions::WrongFileFormatException() << "State reward file specifies illegal reward value."; } @@ -67,6 +69,5 @@ SparseStateRewardParser::SparseStateRewardParser(uint_fast64_t stateCount, std:: } } -} //namespace parser - -} //namespace storm +} // namespace parser +} // namespace storm diff --git a/src/storage/JacobiDecomposition.h b/src/storage/JacobiDecomposition.h index d62f36712..f6de68e98 100644 --- a/src/storage/JacobiDecomposition.h +++ b/src/storage/JacobiDecomposition.h @@ -16,7 +16,7 @@ namespace storage { * Forward declaration against Cycle */ template -class SquareSparseMatrix; +class SparseMatrix; /*! @@ -26,7 +26,7 @@ template class JacobiDecomposition { public: - JacobiDecomposition(storm::storage::SquareSparseMatrix * const jacobiLuMatrix, storm::storage::SquareSparseMatrix * const jacobiDInvMatrix) : jacobiLuMatrix(jacobiLuMatrix), jacobiDInvMatrix(jacobiDInvMatrix) { + JacobiDecomposition(storm::storage::SparseMatrix * const jacobiLuMatrix, storm::storage::SparseMatrix * const jacobiDInvMatrix) : jacobiLuMatrix(jacobiLuMatrix), jacobiDInvMatrix(jacobiDInvMatrix) { } ~JacobiDecomposition() { @@ -39,7 +39,7 @@ public: * Ownership stays with this class. * @return A reference to the Jacobi LU Matrix */ - storm::storage::SquareSparseMatrix& getJacobiLUReference() { + storm::storage::SparseMatrix& getJacobiLUReference() { return *(this->jacobiLuMatrix); } @@ -48,7 +48,7 @@ public: * Ownership stays with this class. * @return A reference to the Jacobi D^{-1} Matrix */ - storm::storage::SquareSparseMatrix& getJacobiDInvReference() { + storm::storage::SparseMatrix& getJacobiDInvReference() { return *(this->jacobiDInvMatrix); } @@ -57,7 +57,7 @@ public: * Ownership stays with this class. * @return A pointer to the Jacobi LU Matrix */ - storm::storage::SquareSparseMatrix* getJacobiLU() { + storm::storage::SparseMatrix* getJacobiLU() { return this->jacobiLuMatrix; } @@ -66,7 +66,7 @@ public: * Ownership stays with this class. * @return A pointer to the Jacobi D^{-1} Matrix */ - storm::storage::SquareSparseMatrix* getJacobiDInv() { + storm::storage::SparseMatrix* getJacobiDInv() { return this->jacobiDInvMatrix; } @@ -81,12 +81,12 @@ private: /*! * Pointer to the LU Matrix */ - storm::storage::SquareSparseMatrix *jacobiLuMatrix; + storm::storage::SparseMatrix *jacobiLuMatrix; /*! * Pointer to the D^{-1} Matrix */ - storm::storage::SquareSparseMatrix *jacobiDInvMatrix; + storm::storage::SparseMatrix *jacobiDInvMatrix; }; } // namespace storage diff --git a/src/storage/SquareSparseMatrix.h b/src/storage/SparseMatrix.h similarity index 74% rename from src/storage/SquareSparseMatrix.h rename to src/storage/SparseMatrix.h index ae48cf4ed..433501efb 100644 --- a/src/storage/SquareSparseMatrix.h +++ b/src/storage/SparseMatrix.h @@ -1,10 +1,11 @@ -#ifndef STORM_STORAGE_SQUARESPARSEMATRIX_H_ -#define STORM_STORAGE_SQUARESPARSEMATRIX_H_ +#ifndef STORM_STORAGE_SPARSEMATRIX_H_ +#define STORM_STORAGE_SPARSEMATRIX_H_ #include #include #include #include +#include #include "boost/integer/integer_mask.hpp" #include "src/exceptions/InvalidStateException.h" @@ -31,13 +32,12 @@ namespace storm { namespace storage { /*! - * A sparse matrix class with a constant number of non-zero entries on the non-diagonal fields - * and a separate dense storage for the diagonal elements. + * A sparse matrix class with a constant number of non-zero entries. * NOTE: Addressing *is* zero-based, so the valid range for getValue and addNextValue is 0..(rows - 1) * where rows is the first argument to the constructor. */ template -class SquareSparseMatrix { +class SparseMatrix { public: /*! * Declare adapter classes as friends to use internal data. @@ -73,9 +73,25 @@ public: * Constructs a sparse matrix object with the given number of rows. * @param rows The number of rows of the matrix */ - SquareSparseMatrix(uint_fast64_t rows) - : rowCount(rows), nonZeroEntryCount(0), valueStorage(nullptr), - diagonalStorage(nullptr),columnIndications(nullptr), rowIndications(nullptr), + SparseMatrix(uint_fast64_t rows, uint_fast64_t cols) + : rowCount(rows), colCount(cols), nonZeroEntryCount(0), + internalStatus(MatrixStatus::UnInitialized), currentSize(0), lastRow(0) { } + + /* Sadly, Delegate Constructors are not yet available with MSVC2012 */ + //! Constructor + /*! + * Constructs a square sparse matrix object with the given number rows + * @param size The number of rows and cols in the matrix + */ /* + SparseMatrix(uint_fast64_t size) : SparseMatrix(size, size) { } + */ + + //! Constructor + /*! + * Constructs a square sparse matrix object with the given number rows + * @param size The number of rows and cols in the matrix + */ + SparseMatrix(uint_fast64_t size) : rowCount(size), colCount(size), nonZeroEntryCount(0), internalStatus(MatrixStatus::UnInitialized), currentSize(0), lastRow(0) { } //! Copy Constructor @@ -83,8 +99,8 @@ public: * Copy Constructor. Performs a deep copy of the given sparse matrix. * @param ssm A reference to the matrix to be copied. */ - SquareSparseMatrix(const SquareSparseMatrix &ssm) - : rowCount(ssm.rowCount), nonZeroEntryCount(ssm.nonZeroEntryCount), + SparseMatrix(const SparseMatrix &ssm) + : rowCount(ssm.rowCount), colCount(ssm.colCount), nonZeroEntryCount(ssm.nonZeroEntryCount), internalStatus(ssm.internalStatus), currentSize(ssm.currentSize), lastRow(ssm.lastRow) { LOG4CPLUS_WARN(logger, "Invoking copy constructor."); // Check whether copying the matrix is safe. @@ -98,24 +114,12 @@ public: LOG4CPLUS_ERROR(logger, "Unable to allocate internal storage."); throw std::bad_alloc(); } else { - // Now that all storages have been prepared, copy over all - // elements. Start by copying the elements of type value and - // copy them seperately in order to invoke copy their copy - // constructor. This may not be necessary, but it is safer to - // do so in any case. - for (uint_fast64_t i = 0; i < nonZeroEntryCount; ++i) { - // use T() to force use of the copy constructor for complex T types - valueStorage[i] = T(ssm.valueStorage[i]); - } - for (uint_fast64_t i = 0; i < rowCount; ++i) { - // use T() to force use of the copy constructor for complex T types - diagonalStorage[i] = T(ssm.diagonalStorage[i]); - } + std::copy(ssm.valueStorage.begin(), ssm.valueStorage.end(), std::back_inserter(valueStorage)); // The elements that are not of the value type but rather the // index type may be copied directly. - std::copy(ssm.columnIndications, ssm.columnIndications + nonZeroEntryCount, columnIndications); - std::copy(ssm.rowIndications, ssm.rowIndications + rowCount + 1, rowIndications); + std::copy(ssm.columnIndications.begin(), ssm.columnIndications.end(), std::back_inserter(columnIndications)); + std::copy(ssm.rowIndications.begin(), ssm.rowIndications.end(), std::back_inserter(rowIndications)); } } } @@ -124,20 +128,11 @@ public: /*! * Destructor. Performs deletion of the reserved storage arrays. */ - ~SquareSparseMatrix() { + ~SparseMatrix() { setState(MatrixStatus::UnInitialized); - if (valueStorage != nullptr) { - delete[] valueStorage; - } - if (columnIndications != nullptr) { - delete[] columnIndications; - } - if (rowIndications != nullptr) { - delete[] rowIndications; - } - if (diagonalStorage != nullptr) { - delete[] diagonalStorage; - } + valueStorage.resize(0); + columnIndications.resize(0); + rowIndications.resize(0); } /*! @@ -155,11 +150,11 @@ public: triggerErrorState(); LOG4CPLUS_ERROR(logger, "Trying to initialize matrix that is not uninitialized."); throw storm::exceptions::InvalidStateException("Trying to initialize matrix that is not uninitialized."); - } else if (rowCount == 0) { + } else if ((rowCount == 0) || (colCount == 0)) { triggerErrorState(); - LOG4CPLUS_ERROR(logger, "Trying to create initialize a matrix with 0 rows."); - throw storm::exceptions::InvalidArgumentException("Trying to create initialize a matrix with 0 rows."); - } else if (((rowCount * rowCount) - rowCount) < nonZeroEntries) { + LOG4CPLUS_ERROR(logger, "Trying to create initialize a matrix with 0 rows or 0 columns."); + throw storm::exceptions::InvalidArgumentException("Trying to create initialize a matrix with 0 rows or 0 columns."); + } else if ((rowCount * colCount) < nonZeroEntries) { triggerErrorState(); LOG4CPLUS_ERROR(logger, "Trying to initialize a matrix with more non-zero entries than there can be."); throw storm::exceptions::InvalidArgumentException("Trying to initialize a matrix with more non-zero entries than there can be."); @@ -196,37 +191,72 @@ public: throw storm::exceptions::InvalidArgumentException("Trying to initialize from an Eigen matrix that is not in compressed form."); } - // Compute the actual (i.e. non-diagonal) number of non-zero entries. - nonZeroEntryCount = getEigenSparseMatrixCorrectNonZeroEntryCount(eigenSparseMatrix); + if (eigenSparseMatrix.rows() > this->rowCount) { + triggerErrorState(); + LOG4CPLUS_ERROR(logger, "Trying to initialize from an Eigen matrix that has more rows than the target matrix."); + throw storm::exceptions::InvalidArgumentException("Trying to initialize from an Eigen matrix that has more rows than the target matrix."); + } + if (eigenSparseMatrix.cols() > this->colCount) { + triggerErrorState(); + LOG4CPLUS_ERROR(logger, "Trying to initialize from an Eigen matrix that has more columns than the target matrix."); + throw storm::exceptions::InvalidArgumentException("Trying to initialize from an Eigen matrix that has more columns than the target matrix."); + } + + const _Index entryCount = eigenSparseMatrix.nonZeros(); + nonZeroEntryCount = entryCount; lastRow = 0; // Try to prepare the internal storage and throw an error in case of // failure. - if (!prepareInternalStorage()) { - triggerErrorState(); - LOG4CPLUS_ERROR(logger, "Unable to allocate internal storage."); - throw std::bad_alloc(); - } else { - // Get necessary pointers to the contents of the Eigen matrix. - const T* valuePtr = eigenSparseMatrix.valuePtr(); - const _Index* indexPtr = eigenSparseMatrix.innerIndexPtr(); - const _Index* outerPtr = eigenSparseMatrix.outerIndexPtr(); - - // If the given matrix is in RowMajor format, copying can simply - // be done by adding all values in order. - // Direct copying is, however, prevented because we have to - // separate the diagonal entries from others. - if (isEigenRowMajor(eigenSparseMatrix)) { - // Because of the RowMajor format outerSize evaluates to the - // number of rows. - const _Index rowCount = eigenSparseMatrix.outerSize(); - for (_Index row = 0; row < rowCount; ++row) { - for (_Index col = outerPtr[row]; col < outerPtr[row + 1]; ++col) { - addNextValue(row, indexPtr[col], valuePtr[col]); - } + + // Get necessary pointers to the contents of the Eigen matrix. + const T* valuePtr = eigenSparseMatrix.valuePtr(); + const _Index* indexPtr = eigenSparseMatrix.innerIndexPtr(); + const _Index* outerPtr = eigenSparseMatrix.outerIndexPtr(); + + // If the given matrix is in RowMajor format, copying can simply + // be done by adding all values in order. + // Direct copying is, however, prevented because we have to + // separate the diagonal entries from others. + if (isEigenRowMajor(eigenSparseMatrix)) { + // Because of the RowMajor format outerSize evaluates to the + // number of rows. + if (!prepareInternalStorage(false)) { + triggerErrorState(); + LOG4CPLUS_ERROR(logger, "Unable to allocate internal storage."); + throw std::bad_alloc(); + } else { + if ((eigenSparseMatrix.innerSize() > nonZeroEntryCount) || (entryCount > nonZeroEntryCount)) { + triggerErrorState(); + LOG4CPLUS_ERROR(logger, "Invalid internal composition of Eigen Sparse Matrix"); + throw storm::exceptions::InvalidArgumentException("Invalid internal composition of Eigen Sparse Matrix"); } + std::vector eigenColumnTemp; + std::vector eigenRowTemp; + std::vector eigenValueTemp; + uint_fast64_t outerSize = eigenSparseMatrix.outerSize() + 1; + + for (uint_fast64_t i = 0; i < entryCount; ++i) { + eigenColumnTemp.push_back(indexPtr[i]); + eigenValueTemp.push_back(valuePtr[i]); + } + for (uint_fast64_t i = 0; i < outerSize; ++i) { + eigenRowTemp.push_back(outerPtr[i]); + } + + std::copy(eigenRowTemp.begin(), eigenRowTemp.end(), std::back_inserter(this->rowIndications)); + std::copy(eigenColumnTemp.begin(), eigenColumnTemp.end(), std::back_inserter(this->columnIndications)); + std::copy(eigenValueTemp.begin(), eigenValueTemp.end(), std::back_inserter(this->valueStorage)); + + currentSize = entryCount; + lastRow = rowCount; + } + } else { + if (!prepareInternalStorage()) { + triggerErrorState(); + LOG4CPLUS_ERROR(logger, "Unable to allocate internal storage."); + throw std::bad_alloc(); } else { - const _Index entryCount = eigenSparseMatrix.nonZeros(); // Because of the ColMajor format outerSize evaluates to the // number of columns. const _Index colCount = eigenSparseMatrix.outerSize(); @@ -250,8 +280,7 @@ public: // add it in case it is also in the current row. if ((positions[currentColumn] < outerPtr[currentColumn + 1]) && (indexPtr[positions[currentColumn]] == currentRow)) { - addNextValue(currentRow, currentColumn, - valuePtr[positions[currentColumn]]); + addNextValue(currentRow, currentColumn, valuePtr[positions[currentColumn]]); // Remember that we found one more non-zero element. ++i; // Mark this position as "used". @@ -268,8 +297,8 @@ public: } delete[] positions; } - setState(MatrixStatus::Initialized); } + setState(MatrixStatus::Initialized); } /*! @@ -283,30 +312,27 @@ public: void addNextValue(const uint_fast64_t row, const uint_fast64_t col, const T& value) { // Check whether the given row and column positions are valid and throw // error otherwise. - if ((row > rowCount) || (col > rowCount)) { + if ((row > rowCount) || (col > colCount)) { triggerErrorState(); LOG4CPLUS_ERROR(logger, "Trying to add a value at illegal position (" << row << ", " << col << ")."); throw storm::exceptions::OutOfRangeException("Trying to add a value at illegal position."); } - if (row == col) { // Set a diagonal element. - diagonalStorage[row] = value; - } else { // Set a non-diagonal element. - // If we switched to another row, we have to adjust the missing - // entries in the row_indications array. - if (row != lastRow) { - for (uint_fast64_t i = lastRow + 1; i <= row; ++i) { - rowIndications[i] = currentSize; - } - lastRow = row; + + // If we switched to another row, we have to adjust the missing + // entries in the row_indications array. + if (row != lastRow) { + for (uint_fast64_t i = lastRow + 1; i <= row; ++i) { + rowIndications[i] = currentSize; } + lastRow = row; + } - // Finally, set the element and increase the current size. - valueStorage[currentSize] = value; - columnIndications[currentSize] = col; + // Finally, set the element and increase the current size. + valueStorage[currentSize] = value; + columnIndications[currentSize] = col; - ++currentSize; - } + ++currentSize; } /* @@ -355,18 +381,12 @@ public: */ inline bool getValue(uint_fast64_t row, uint_fast64_t col, T* const target) const { // Check for illegal access indices. - if ((row > rowCount) || (col > rowCount)) { + if ((row > rowCount) || (col > colCount)) { LOG4CPLUS_ERROR(logger, "Trying to read a value from illegal position (" << row << ", " << col << ")."); throw storm::exceptions::OutOfRangeException("Trying to read a value from illegal position."); return false; } - // Read elements on the diagonal directly. - if (row == col) { - *target = diagonalStorage[row]; - return true; - } - // In case the element is not on the diagonal, we have to iterate // over the accessed row to find the element. uint_fast64_t rowStart = rowIndications[row]; @@ -405,17 +425,12 @@ public: */ inline T& getValue(uint_fast64_t row, uint_fast64_t col) { // Check for illegal access indices. - if ((row > rowCount) || (col > rowCount)) { + if ((row > rowCount) || (col > colCount)) { LOG4CPLUS_ERROR(logger, "Trying to read a value from illegal position (" << row << ", " << col << ")."); throw storm::exceptions::OutOfRangeException("Trying to read a value from illegal position."); } - // Read elements on the diagonal directly. - if (row == col) { - return diagonalStorage[row]; - } - - // In case the element is not on the diagonal, we have to iterate + // we have to iterate // over the accessed row to find the element. uint_fast64_t rowStart = rowIndications[row]; uint_fast64_t rowEnd = rowIndications[row + 1]; @@ -445,20 +460,18 @@ public: } /*! - * Returns a pointer to the value storage of the matrix. This storage does - * *not* include elements on the diagonal. - * @return A pointer to the value storage of the matrix. + * Returns the number of columns of the matrix. */ - T* getStoragePointer() const { - return valueStorage; + uint_fast64_t getColumnCount() const { + return colCount; } /*! - * Returns a pointer to the storage of elements on the diagonal. - * @return A pointer to the storage of elements on the diagonal. + * Returns a pointer to the value storage of the matrix. + * @return A pointer to the value storage of the matrix. */ - T* getDiagonalStoragePointer() const { - return diagonalStorage; + std::vector const & getStoragePointer() const { + return valueStorage; } /*! @@ -467,17 +480,17 @@ public: * @return A pointer to the array that stores the start indices of non-zero * entries in the value storage for each row. */ - uint_fast64_t* getRowIndicationsPointer() const { + std::vector const & getRowIndicationsPointer() const { return rowIndications; } /*! * Returns a pointer to an array that stores the column of each non-zero - * element that is not on the diagonal. + * element. * @return A pointer to an array that stores the column of each non-zero - * element that is not on the diagonal. + * element. */ - uint_fast64_t* getColumnIndicationsPointer() const { + std::vector const & getColumnIndicationsPointer() const { return columnIndications; } @@ -548,10 +561,6 @@ public: #define STORM_USE_TRIPLETCONVERT # ifdef STORM_USE_TRIPLETCONVERT - // FIXME: Wouldn't it be more efficient to add the elements in - // order including the diagonal elements? Otherwise, Eigen has to - // perform some sorting. - // Prepare the triplet storage. typedef Eigen::Triplet IntTriplet; std::vector tripletList; @@ -572,12 +581,6 @@ public: } } - // Then add the elements on the diagonal. - for (uint_fast64_t i = 0; i < rowCount; ++i) { - if (diagonalStorage[i] == 0) zeroCount++; - tripletList.push_back(IntTriplet(static_cast(i), static_cast(i), diagonalStorage[i])); - } - // Let Eigen create a matrix from the given list of triplets. mat->setFromTriplets(tripletList.begin(), tripletList.end()); @@ -596,10 +599,6 @@ public: rowStart = rowIndications[row]; rowEnd = rowIndications[row + 1]; - // Insert the element on the diagonal. - mat->insert(row, row) = diagonalStorage[row]; - count++; - // Insert the elements that are not on the diagonal while (rowStart < rowEnd) { mat->insert(row, columnIndications[rowStart]) = valueStorage[rowStart]; @@ -628,19 +627,6 @@ public: return nonZeroEntryCount; } - /*! - * Returns the number of non-zero entries on the diagonal. - * @return The number of non-zero entries on the diagonal. - */ - uint_fast64_t getDiagonalNonZeroEntryCount() const { - uint_fast64_t result = 0; - T zero(0); - for (uint_fast64_t i = 0; i < rowCount; ++i) { - if (diagonalStorage[i] != zero) ++result; - } - return result; - } - /*! * This function makes the rows given by the bit vector absorbing. * @param rows A bit vector indicating which rows to make absorbing. @@ -658,7 +644,7 @@ public: /*! * This function makes the given row absorbing. This means that all * entries in will be set to 0 and the value 1 will be written - * to the element on the diagonal. + * to the element on the (pseudo-) diagonal. * @param row The row to be made absorbing. * @returns True iff the operation was successful. */ @@ -675,13 +661,31 @@ public: uint_fast64_t rowStart = rowIndications[row]; uint_fast64_t rowEnd = rowIndications[row + 1]; + if (rowStart >= rowEnd) { + LOG4CPLUS_ERROR(logger, "The row " << row << " can not be made absorbing, no state in row, would have to recreate matrix!"); + throw storm::exceptions::InvalidStateException("A row can not be made absorbing, no state in row, would have to recreate matrix!"); + } + uint_fast64_t pseudoDiagonal = row % colCount; + + bool foundDiagonal = false; while (rowStart < rowEnd) { - valueStorage[rowStart] = storm::utility::constGetZero(); + if (!foundDiagonal && columnIndications[rowStart] >= pseudoDiagonal) { + foundDiagonal = true; + // insert/replace the diagonal entry + columnIndications[rowStart] = pseudoDiagonal; + valueStorage[rowStart] = storm::utility::constGetOne(); + } else { + valueStorage[rowStart] = storm::utility::constGetZero(); + } ++rowStart; } - // Set the element on the diagonal to one. - diagonalStorage[row] = storm::utility::constGetOne(); + if (!foundDiagonal) { + --rowStart; + columnIndications[rowStart] = pseudoDiagonal; + valueStorage[rowStart] = storm::utility::constGetOne(); + } + return true; } @@ -724,7 +728,7 @@ public: * @param constraint A bit vector indicating which rows and columns to drop. * @return A pointer to a sparse matrix that is a sub-matrix of the current one. */ - SquareSparseMatrix* getSubmatrix(storm::storage::BitVector& constraint) const { + SparseMatrix* getSubmatrix(storm::storage::BitVector& constraint) const { LOG4CPLUS_DEBUG(logger, "Creating a sub-matrix with " << constraint.getNumberOfSetBits() << " rows."); // Check for valid constraint. @@ -745,7 +749,7 @@ public: } // Create and initialize resulting matrix. - SquareSparseMatrix* result = new SquareSparseMatrix(constraint.getNumberOfSetBits()); + SparseMatrix* result = new SparseMatrix(constraint.getNumberOfSetBits()); result->initialize(subNonZeroEntries); // Create a temporary array that stores for each index whose bit is set @@ -763,8 +767,6 @@ public: // Copy over selected entries. uint_fast64_t rowCount = 0; for (auto rowIndex : constraint) { - result->addNextValue(rowCount, rowCount, diagonalStorage[rowIndex]); - for (uint_fast64_t i = rowIndications[rowIndex]; i < rowIndications[rowIndex + 1]; ++i) { if (constraint.get(columnIndications[i])) { result->addNextValue(rowCount, bitsSetBeforeIndex[columnIndications[i]], valueStorage[i]); @@ -793,9 +795,20 @@ public: * value. */ void invertDiagonal() { + if (this->getRowCount() != this->getColumnCount()) { + throw storm::exceptions::InvalidArgumentException() << "SparseMatrix::invertDiagonal requires the Matrix to be square!"; + } T one(1); - for (uint_fast64_t i = 0; i < rowCount; ++i) { - diagonalStorage[i] = one - diagonalStorage[i]; + for (uint_fast64_t row = 0; row < rowCount; ++row) { + uint_fast64_t rowStart = rowIndications[row]; + uint_fast64_t rowEnd = rowIndications[row + 1]; + while (rowStart < rowEnd) { + if (columnIndications[rowStart] == row) { + valueStorage[rowStart] = one - valueStorage[rowStart]; + break; + } + ++rowStart; + } } } @@ -803,8 +816,18 @@ public: * Negates all non-zero elements that are not on the diagonal. */ void negateAllNonDiagonalElements() { - for (uint_fast64_t i = 0; i < nonZeroEntryCount; ++i) { - valueStorage[i] = - valueStorage[i]; + if (this->getRowCount() != this->getColumnCount()) { + throw storm::exceptions::InvalidArgumentException() << "SparseMatrix::invertDiagonal requires the Matrix to be square!"; + } + for (uint_fast64_t row = 0; row < rowCount; ++row) { + uint_fast64_t rowStart = rowIndications[row]; + uint_fast64_t rowEnd = rowIndications[row + 1]; + while (rowStart < rowEnd) { + if (columnIndications[rowStart] != row) { + valueStorage[rowStart] = - valueStorage[rowStart]; + } + ++rowStart; + } } } @@ -814,8 +837,8 @@ public: */ storm::storage::JacobiDecomposition* getJacobiDecomposition() const { uint_fast64_t rowCount = this->getRowCount(); - SquareSparseMatrix *resultLU = new SquareSparseMatrix(this); - SquareSparseMatrix *resultDinv = new SquareSparseMatrix(rowCount); + SparseMatrix *resultLU = new SparseMatrix(this); + SparseMatrix *resultDinv = new SparseMatrix(rowCount); // no entries apart from those on the diagonal resultDinv->initialize(0); // copy diagonal entries to other matrix @@ -836,7 +859,7 @@ public: * @return A vector containing the sum of the elements in each row of the matrix resulting from * pointwise multiplication of the current matrix with the given matrix. */ - std::vector* getPointwiseProductRowSumVector(storm::storage::SquareSparseMatrix const& otherMatrix) { + std::vector* getPointwiseProductRowSumVector(storm::storage::SparseMatrix const& otherMatrix) { // Prepare result. std::vector* result = new std::vector(rowCount); @@ -844,7 +867,6 @@ public: // in case the given matrix does not have a non-zero element at this column position, or // multiply the two entries and add the result to the corresponding position in the vector. for (uint_fast64_t row = 0; row < rowCount && row < otherMatrix.rowCount; ++row) { - (*result)[row] += diagonalStorage[row] * otherMatrix.diagonalStorage[row]; for (uint_fast64_t element = rowIndications[row], nextOtherElement = otherMatrix.rowIndications[row]; element < rowIndications[row + 1] && nextOtherElement < otherMatrix.rowIndications[row + 1]; ++element) { if (columnIndications[element] < otherMatrix.columnIndications[nextOtherElement]) { continue; @@ -868,25 +890,23 @@ public: uint_fast64_t getSizeInMemory() const { uint_fast64_t size = sizeof(*this); // Add value_storage size. - size += sizeof(T) * nonZeroEntryCount; - // Add diagonal_storage size. - size += sizeof(T) * (rowCount + 1); + size += sizeof(T) * valueStorage.capacity(); // Add column_indications size. - size += sizeof(uint_fast64_t) * nonZeroEntryCount; + size += sizeof(uint_fast64_t) * columnIndications.capacity(); // Add row_indications size. - size += sizeof(uint_fast64_t) * (rowCount + 1); + size += sizeof(uint_fast64_t) * rowIndications.capacity(); return size; } /*! * Returns an iterator to the columns of the non-zero entries of the given - * row that are not on the diagonal. + * row. * @param row The row whose columns the iterator will return. * @return An iterator to the columns of the non-zero entries of the given - * row that are not on the diagonal. + * row. */ - constIndexIterator beginConstColumnNoDiagIterator(uint_fast64_t row) const { - return this->columnIndications + this->rowIndications[row]; + constIndexIterator beginConstColumnIterator(uint_fast64_t row) const { + return &(this->columnIndications[0]) + this->rowIndications[row]; } /*! @@ -894,18 +914,18 @@ public: * @param row The row for which the iterator should point to the past-the-end * element. */ - constIndexIterator endConstColumnNoDiagIterator(uint_fast64_t row) const { - return this->columnIndications + this->rowIndications[row + 1]; + constIndexIterator endConstColumnIterator(uint_fast64_t row) const { + return &(this->columnIndications[0]) + this->rowIndications[row + 1]; } /*! * Returns an iterator over the elements of the given row. The iterator - * will include neither the diagonal element nor zero entries. + * will include no zero entries. * @param row The row whose elements the iterator will return. * @return An iterator over the elements of the given row. */ - constIterator beginConstNoDiagIterator(uint_fast64_t row) const { - return this->valueStorage + this->rowIndications[row]; + constIterator beginConstIterator(uint_fast64_t row) const { + return &(this->valueStorage[0]) + this->rowIndications[row]; } /*! * Returns an iterator pointing to the first element after the given @@ -914,32 +934,28 @@ public: * past-the-end element. * @return An iterator to the element after the given row. */ - constIterator endConstNoDiagIterator(uint_fast64_t row) const { - return this->valueStorage + this->rowIndications[row + 1]; + constIterator endConstIterator(uint_fast64_t row) const { + return &(this->valueStorage[0]) + this->rowIndications[row + 1]; } /*! * @brief Calculate sum of all entries in given row. * - * Adds up all values in the given row (including the diagonal value) + * Adds up all values in the given row * and returns the sum. * @param row The row that should be added up. * @return Sum of the row. */ T getRowSum(uint_fast64_t row) const { - T sum = this->diagonalStorage[row]; - for (auto it = this->beginConstNoDiagIterator(row); it != this->endConstNoDiagIterator(row); it++) { + T sum = storm::utility::constGetZero(); + for (auto it = this->beginConstIterator(row); it != this->endConstIterator(row); it++) { sum += *it; } return sum; } void print() const { - std::cout << "diag: --------------------------------" << std::endl; - for (uint_fast64_t i = 0; i < rowCount; ++i) { - std::cout << "(" << i << "," << i << ") = " << diagonalStorage[i] << std::endl; - } - std::cout << "non diag: ----------------------------" << std::endl; + std::cout << "entries: ----------------------------" << std::endl; for (uint_fast64_t i = 0; i < rowCount; ++i) { for (uint_fast64_t j = rowIndications[i]; j < rowIndications[i + 1]; ++j) { std::cout << "(" << i << "," << columnIndications[j] << ") = " << valueStorage[j] << std::endl; @@ -955,31 +971,31 @@ private: uint_fast64_t rowCount; /*! - * The number of non-zero elements that are not on the diagonal. + * The number of columns of the matrix. */ - uint_fast64_t nonZeroEntryCount; + uint_fast64_t colCount; /*! - * Stores all non-zero values that are not on the diagonal. + * The number of non-zero elements. */ - T* valueStorage; + uint_fast64_t nonZeroEntryCount; /*! - * Stores all elements on the diagonal, even the ones that are zero. + * Stores all non-zero values. */ - T* diagonalStorage; + std::vector valueStorage; /*! - * Stores the column for each non-zero element that is not on the diagonal. + * Stores the column for each non-zero element. */ - uint_fast64_t* columnIndications; + std::vector columnIndications; /*! - * Array containing the boundaries (indices) in the value_storage array + * Vector containing the boundaries (indices) in the value_storage array * for each row. All elements of value_storage with indices between the * i-th and the (i+1)-st element of this array belong to row i. */ - uint_fast64_t* rowIndications; + std::vector rowIndications; /*! * The internal status of the matrix. @@ -1017,24 +1033,37 @@ private: /*! * Prepares the internal CSR storage. For this, it requires * non_zero_entry_count and row_count to be set correctly. + * @param alsoPerformAllocation If set to true, all entries are pre-allocated. This is the default. * @return True on success, false otherwise (allocation failed). */ - bool prepareInternalStorage() { - // Set up the arrays for the elements that are not on the diagonal. - valueStorage = new (std::nothrow) T[nonZeroEntryCount](); - columnIndications = new (std::nothrow) uint_fast64_t[nonZeroEntryCount](); - - // Set up the row_indications array and reserve one element more than - // there are rows in order to put a sentinel element at the end, - // which eases iteration process. - rowIndications = new (std::nothrow) uint_fast64_t[rowCount + 1](); - - // Set up the array for the elements on the diagonal. - diagonalStorage = new (std::nothrow) T[rowCount](); + bool prepareInternalStorage(const bool alsoPerformAllocation) { + if (alsoPerformAllocation) { + // Set up the arrays for the elements that are not on the diagonal. + valueStorage.resize(nonZeroEntryCount, storm::utility::constGetZero()); + columnIndications.resize(nonZeroEntryCount, 0); + + // Set up the row_indications vector and reserve one element more than + // there are rows in order to put a sentinel element at the end, + // which eases iteration process. + rowIndications.resize(rowCount + 1, 0); + + // Return whether all the allocations could be made without error. + return ((valueStorage.capacity() >= nonZeroEntryCount) && (columnIndications.capacity() >= nonZeroEntryCount) + && (rowIndications.capacity() >= (rowCount + 1))); + } else { + valueStorage.reserve(nonZeroEntryCount); + columnIndications.reserve(nonZeroEntryCount); + rowIndications.reserve(rowCount + 1); + return true; + } + } - // Return whether all the allocations could be made without error. - return ((valueStorage != NULL) && (columnIndications != NULL) - && (rowIndications != NULL) && (diagonalStorage != NULL)); + /*! + * Shorthand for prepareInternalStorage(true) + * @see prepareInternalStorage(const bool) + */ + bool prepareInternalStorage() { + return this->prepareInternalStorage(true); } /*! @@ -1060,57 +1089,10 @@ private: return false; } - /*! - * Helper function to determine the number of non-zero elements that are - * not on the diagonal of the given Eigen matrix. - * @param eigen_sparse_matrix The Eigen matrix to analyze. - * @return The number of non-zero elements that are not on the diagonal of - * the given Eigen matrix. - */ - template - _Index getEigenSparseMatrixCorrectNonZeroEntryCount(const Eigen::SparseMatrix<_Scalar, _Options, _Index>& eigen_sparse_matrix) const { - const _Index* indexPtr = eigen_sparse_matrix.innerIndexPtr(); - const _Index* outerPtr = eigen_sparse_matrix.outerIndexPtr(); - - const _Index entryCount = eigen_sparse_matrix.nonZeros(); - const _Index outerCount = eigen_sparse_matrix.outerSize(); - - uint_fast64_t diagNonZeros = 0; - - // For RowMajor, row is the current row and col the column and for - // ColMajor, row is the current column and col the row, but this is - // not important as we are only looking for elements on the diagonal. - _Index innerStart = 0; - _Index innerEnd = 0; - _Index innerMid = 0; - for (_Index row = 0; row < outerCount; ++row) { - innerStart = outerPtr[row]; - innerEnd = outerPtr[row + 1] - 1; - - // Now use binary search (but defer equality detection). - while (innerStart < innerEnd) { - innerMid = innerStart + ((innerEnd - innerStart) / 2); - - if (indexPtr[innerMid] < row) { - innerStart = innerMid + 1; - } else { - innerEnd = innerMid; - } - } - - // Check whether we have found an element on the diagonal. - if ((innerStart == innerEnd) && (indexPtr[innerStart] == row)) { - ++diagNonZeros; - } - } - - return static_cast<_Index>(entryCount - diagNonZeros); - } - }; } // namespace storage } // namespace storm -#endif // STORM_STORAGE_SQUARESPARSEMATRIX_H_ +#endif // STORM_STORAGE_SPARSEMATRIX_H_ diff --git a/src/storm.cpp b/src/storm.cpp index 220ae5a02..a8390a7e4 100644 --- a/src/storm.cpp +++ b/src/storm.cpp @@ -20,12 +20,12 @@ #include "storm-config.h" #include "src/models/Dtmc.h" -#include "src/storage/SquareSparseMatrix.h" +#include "src/storage/SparseMatrix.h" #include "src/models/AtomicPropositionsLabeling.h" #include "src/modelChecker/EigenDtmcPrctlModelChecker.h" #include "src/modelChecker/GmmxxDtmcPrctlModelChecker.h" -#include "src/parser/DtmcParser.h" -// #include "src/parser/PrctlParser.h" +#include "src/parser/AutoParser.h" +#include "src/parser/PrctlParser.h" #include "src/solver/GraphAnalyzer.h" #include "src/utility/Settings.h" #include "src/formula/Formulas.h" @@ -219,10 +219,13 @@ void testCheckingSynchronousLeader(storm::models::Dtmc& dtmc, uint_fast6 */ void testChecking() { storm::settings::Settings* s = storm::settings::instance(); - storm::parser::DtmcParser dtmcParser(s->getString("trafile"), s->getString("labfile"), s->getString("staterew"), s->getString("transrew")); - std::shared_ptr> dtmc = dtmcParser.getDtmc(); + storm::parser::AutoParser parser(s->getString("trafile"), s->getString("labfile"), s->getString("staterew"), s->getString("transrew")); - dtmc->printModelInformationToStream(std::cout); + if (parser.getType() == storm::models::DTMC) { + std::shared_ptr> dtmc = parser.getModel>(); + dtmc->printModelInformationToStream(std::cout); + } + else std::cout << "Input is not DTMC" << std::endl; // testCheckingDie(*dtmc); // testCheckingCrowds(*dtmc); diff --git a/src/utility/CommandLine.cpp b/src/utility/CommandLine.cpp index c79948f43..f8f8321b4 100644 --- a/src/utility/CommandLine.cpp +++ b/src/utility/CommandLine.cpp @@ -8,13 +8,11 @@ #include namespace storm { - namespace utility { void printSeparationLine(std::ostream& out) { out << "------------------------------------------------------" << std::endl; } -} // namespace utility - -} // namespace storm +} // namespace utility +} // namespace storm diff --git a/src/utility/IoUtility.cpp b/src/utility/IoUtility.cpp index e3fc7dbaf..1c2f75c3c 100644 --- a/src/utility/IoUtility.cpp +++ b/src/utility/IoUtility.cpp @@ -1,75 +1,71 @@ /* - * IoUtility.cpp - * - * Created on: 17.10.2012 - * Author: Thomas Heinemann - */ +* IoUtility.cpp +* +* Created on: 17.10.2012 +* Author: Thomas Heinemann +*/ #include "src/utility/IoUtility.h" -#include "src/parser/DeterministicSparseTransitionParser.h" -#include "src/parser/AtomicPropositionLabelingParser.h" #include +#include "src/parser/DeterministicSparseTransitionParser.h" +#include "src/parser/AtomicPropositionLabelingParser.h" + namespace storm { -namespace utility { - -void dtmcToDot(storm::models::Dtmc const &dtmc, std::string filename) { - std::shared_ptr> matrix(dtmc.getTransitionProbabilityMatrix()); - double* diagonal_storage = matrix->getDiagonalStoragePointer(); - - std::ofstream file; - file.open(filename); - - file << "digraph dtmc {\n"; - - //Specify the nodes and their labels - for (uint_fast64_t i = 1; i < dtmc.getNumberOfStates(); i++) { - file << "\t" << i << "[label=\"" << i << "\\n{"; - char komma=' '; - std::set propositions = dtmc.getPropositionsForState(i); - for(auto it = propositions.begin(); - it != propositions.end(); - it++) { - file << komma << *it; - komma=','; - } - - file << " }\"];\n"; - - } - - for (uint_fast64_t row = 0; row < dtmc.getNumberOfStates(); row++ ) { - //write diagonal entry/self loop first - if (diagonal_storage[row] != 0) { - file << "\t" << row << " -> " << row << " [label=" << diagonal_storage[row] <<"]\n"; - } - //Then, iterate through the row and write each non-diagonal value into the file - for ( auto it = matrix->beginConstColumnNoDiagIterator(row); - it != matrix->endConstColumnNoDiagIterator(row); - it++) { - double value = 0; - matrix->getValue(row,*it,&value); - file << "\t" << row << " -> " << *it << " [label=" << value << "]\n"; - } - } - - file << "}\n"; - file.close(); -} + namespace utility { -//TODO: Should this stay here or be integrated in the new parser structure? -/*storm::models::Dtmc* parseDTMC(std::string const &tra_file, std::string const &lab_file) { - storm::parser::DeterministicSparseTransitionParser tp(tra_file); - uint_fast64_t node_count = tp.getMatrix()->getRowCount(); + void dtmcToDot(storm::models::Dtmc const &dtmc, std::string filename) { + std::shared_ptr> matrix(dtmc.getTransitionProbabilityMatrix()); + std::ofstream file; + file.open(filename); - storm::parser::AtomicPropositionLabelingParser lp(node_count, lab_file); + file << "digraph dtmc {\n"; - storm::models::Dtmc* result = new storm::models::Dtmc(tp.getMatrix(), lp.getLabeling()); - return result; -}*/ + //Specify the nodes and their labels + for (uint_fast64_t i = 1; i < dtmc.getNumberOfStates(); i++) { + file << "\t" << i << "[label=\"" << i << "\\n{"; + char komma=' '; + std::set propositions = dtmc.getPropositionsForState(i); + for(auto it = propositions.begin(); + it != propositions.end(); + it++) { + file << komma << *it; + komma=','; + } -} + file << " }\"];\n"; + + } + + for (uint_fast64_t row = 0; row < dtmc.getNumberOfStates(); row++ ) { + + //Then, iterate through the row and write each non-diagonal value into the file + for ( auto it = matrix->beginConstColumnIterator(row); + it != matrix->endConstColumnIterator(row); + it++) { + double value = 0; + matrix->getValue(row,*it,&value); + file << "\t" << row << " -> " << *it << " [label=" << value << "]\n"; + } + } + + file << "}\n"; + file.close(); + } + + //TODO: Should this stay here or be integrated in the new parser structure? + /*storm::models::Dtmc* parseDTMC(std::string const &tra_file, std::string const &lab_file) { + storm::parser::DeterministicSparseTransitionParser tp(tra_file); + uint_fast64_t node_count = tp.getMatrix()->getRowCount(); + + storm::parser::AtomicPropositionLabelingParser lp(node_count, lab_file); + + storm::models::Dtmc* result = new storm::models::Dtmc(tp.getMatrix(), lp.getLabeling()); + return result; + }*/ + + } } diff --git a/src/utility/Settings.cpp b/src/utility/Settings.cpp index 0c89d08d8..3c9b3e4d3 100644 --- a/src/utility/Settings.cpp +++ b/src/utility/Settings.cpp @@ -7,13 +7,17 @@ #include "src/utility/Settings.h" -#include "src/exceptions/BaseException.h" +#include +#include +#include +#include +#include +#include "src/exceptions/BaseException.h" #include "log4cplus/logger.h" #include "log4cplus/loggingmacros.h" extern log4cplus::Logger logger; -#include namespace storm { namespace settings { @@ -21,7 +25,7 @@ namespace settings { namespace bpo = boost::program_options; /* - * static initializers + * Static initializers. */ std::unique_ptr storm::settings::Settings::desc = nullptr; std::string storm::settings::Settings::binaryName = ""; @@ -42,37 +46,37 @@ std::map< std::pair, std::shared_ptrinitDescriptions(); - // Take care of positional arguments + // Take care of positional arguments. Settings::positional.add("trafile", 1); Settings::positional.add("labfile", 1); - // Check module triggers, add corresponding options + // Check module triggers, add corresponding options. std::map< std::string, std::list< std::string > > options; - + for (auto it : Settings::modules) { options[it.first.first].push_back(it.first.second); } for (auto it : options) { std::stringstream str; str << "select " << it.first << " module (" << boost::algorithm::join(it.second, ", ") << ")"; - + Settings::desc->add_options() (it.first.c_str(), bpo::value()->default_value(it.second.front()), str.str().c_str()) ; } - - // Perform first parse run + + // Perform first parse run. this->firstRun(argc, argv, filename); - - // Buffer for items to be deleted + + // Buffer for items to be deleted. std::list< std::pair< std::string, std::string > > deleteQueue; - // Check module triggers + // Check module triggers. for (auto it : Settings::modules) { std::pair< std::string, std::string > trigger = it.first; if (this->vm.count(trigger.first)) { @@ -83,17 +87,16 @@ Settings::Settings(const int argc, const char* argv[], const char* filename) { } } for (auto it : deleteQueue) Settings::modules.erase(it); - - - // Stop if help is set + + // Stop if help is set. if (this->vm.count("help") > 0) { return; } - - // Perform second run + + // Perform second run. this->secondRun(argc, argv, filename); - - // Finalize parsed options, check for specified requirements + + // Finalize parsed options, check for specified requirements. bpo::notify(this->vm); LOG4CPLUS_DEBUG(logger, "Finished loading config."); } @@ -117,7 +120,6 @@ Settings::Settings(const int argc, const char* argv[], const char* filename) { /*! * Initially fill options_description objects. - * First puts some generic options, then calls all register Callbacks. */ void Settings::initDescriptions() { LOG4CPLUS_DEBUG(logger, "Initializing descriptions."); @@ -132,6 +134,7 @@ void Settings::initDescriptions() { ("labfile", bpo::value()->required(), "name of the .lab file") ("transrew", bpo::value()->default_value(""), "name of transition reward file") ("staterew", bpo::value()->default_value(""), "name of state reward file") + ("fix-deadlocks", "insert self-loops for states without outgoing transitions") ; } @@ -140,13 +143,13 @@ void Settings::initDescriptions() { * given), but allow for unregistered options, do not check requirements * from options_description objects, do not check positional arguments. */ -void Settings::firstRun(const int argc, const char* argv[], const char* filename) { +void Settings::firstRun(int const argc, char const * const argv[], char const * const filename) { LOG4CPLUS_DEBUG(logger, "Performing first run."); - // parse command line + // Parse command line. bpo::store(bpo::command_line_parser(argc, argv).options(*(Settings::desc)).allow_unregistered().run(), this->vm); /* - * load config file if specified + * Load config file if specified. */ if (this->vm.count("configfile")) { bpo::store(bpo::parse_config_file(this->vm["configfile"].as().c_str(), *(Settings::desc)), this->vm, true); @@ -160,12 +163,12 @@ void Settings::firstRun(const int argc, const char* argv[], const char* filename * given) and check for unregistered options, requirements from * options_description objects and positional arguments. */ -void Settings::secondRun(const int argc, const char* argv[], const char* filename) { +void Settings::secondRun(int const argc, char const * const argv[], char const * const filename) { LOG4CPLUS_DEBUG(logger, "Performing second run."); - // Parse command line + // Parse command line. bpo::store(bpo::command_line_parser(argc, argv).options(*(Settings::desc)).positional(this->positional).run(), this->vm); /* - * load config file if specified + * Load config file if specified. */ if (this->vm.count("configfile")) { bpo::store(bpo::parse_config_file(this->vm["configfile"].as().c_str(), *(Settings::desc)), this->vm, true); @@ -191,5 +194,5 @@ std::ostream& help(std::ostream& os) { return os; } -} // namespace settings -} // namespace storm +} // namespace settings +} // namespace storm diff --git a/src/utility/Settings.h b/src/utility/Settings.h index 37208f990..b85ff9256 100644 --- a/src/utility/Settings.h +++ b/src/utility/Settings.h @@ -51,7 +51,7 @@ namespace settings { * @brief Get value of a generic option. */ template - const T& get(const std::string &name) const { + inline const T& get( std::string const & name) const { if (this->vm.count(name) == 0) throw storm::exceptions::InvalidSettingsException() << "Could not read option " << name << "."; return this->vm[name].as(); } @@ -59,14 +59,14 @@ namespace settings { /*! * @brief Get value of string option */ - const std::string& getString(const std::string &name) const { + inline const std::string& getString(std::string const & name) const { return this->get(name); } /*! * @brief Check if an option is set */ - const bool isSet(const std::string &name) const { + inline const bool isSet(std::string const & name) const { return this->vm.count(name) > 0; } @@ -107,28 +107,28 @@ namespace settings { */ template static void registerModule() { - // get trigger + // Get trigger values. std::pair< std::string, std::string > trigger = T::getOptionTrigger(); - // build description name + // Build description name. std::stringstream str; str << "Options for " << T::getModuleName() << " (" << trigger.first << " = " << trigger.second << ")"; std::shared_ptr desc = std::shared_ptr(new bpo::options_description(str.str())); - // but options + // Put options into description. T::putOptions(desc.get()); - // store + // Store module. Settings::modules[ trigger ] = desc; } friend std::ostream& help(std::ostream& os); friend std::ostream& helpConfigfile(std::ostream& os); friend Settings* instance(); - friend Settings* newInstance(const int argc, const char* argv[], const char* filename); + friend Settings* newInstance(int const argc, char const * const argv[], char const * const filename); private: /*! * @brief Constructor. */ - Settings(const int argc, const char* argv[], const char* filename); + Settings(int const argc, char const * const argv[], char const * const filename); /*! * @brief Initialize options_description object. @@ -138,12 +138,12 @@ namespace settings { /*! * @brief Perform first parser run */ - void firstRun(const int argc, const char* argv[], const char* filename); + void firstRun(int const argc, char const * const argv[], char const * const filename); /*! * @brief Perform second parser run. */ - void secondRun(const int argc, const char* argv[], const char* filename); + void secondRun(int const argc, char const * const argv[], char const * const filename); /*! * @brief Option description for positional arguments on command line. @@ -197,10 +197,10 @@ namespace settings { * * @param argc should be argc passed to main function * @param argv should be argv passed to main function - * @param filename either NULL or name of config file + * @param filename either NULL or name of config file * @return The new instance of Settings. */ - inline Settings* newInstance(const int argc, const char* argv[], const char* filename) { + inline Settings* newInstance(int const argc, char const * const argv[], char const * const filename) { if (Settings::inst != nullptr) delete Settings::inst; Settings::inst = new Settings(argc, argv, filename); return Settings::inst; diff --git a/src/vector/dense_vector.h b/src/vector/dense_vector.h new file mode 100644 index 000000000..b118d8e5d --- /dev/null +++ b/src/vector/dense_vector.h @@ -0,0 +1,183 @@ +#ifndef MRMC_VECTOR_BITVECTOR_H_ +#define MRMC_VECTOR_BITVECTOR_H_ + +#include +#include +#include +#include "boost/integer/integer_mask.hpp" + +#include +#include + +#include "src/exceptions/invalid_state.h" +#include "src/exceptions/invalid_argument.h" +#include "src/exceptions/out_of_range.h" + +namespace mrmc { + +namespace vector { + +//! A Vector +/*! + A bit vector for boolean fields or quick selection schemas on Matrix entries. + Does NOT perform index bound checks! + */ +template +class DenseVector { + public: + //! Constructor + /*! + \param initial_length The initial size of the boolean Array. Can be changed later on via BitVector::resize() + */ + BitVector(uint_fast64_t initial_length) { + bucket_count = initial_length / 64; + if (initial_length % 64 != 0) { + ++bucket_count; + } + bucket_array = new uint_fast64_t[bucket_count](); + + // init all 0 + for (uint_fast64_t i = 0; i < bucket_count; ++i) { + bucket_array[i] = 0; + } + } + + //! Copy Constructor + /*! + Copy Constructor. Creates an exact copy of the source bit vector bv. Modification of either bit vector does not affect the other. + @param bv A reference to the bit vector that should be copied from + */ + BitVector(const BitVector &bv) : bucket_count(bv.bucket_count) + { + pantheios::log_DEBUG("BitVector::CopyCTor: Using Copy() Ctor."); + bucket_array = new uint_fast64_t[bucket_count](); + memcpy(bucket_array, bv.bucket_array, sizeof(uint_fast64_t) * bucket_count); + } + + ~BitVector() { + if (bucket_array != NULL) { + delete[] bucket_array; + } + } + + void resize(uint_fast64_t new_length) { + uint_fast64_t* tempArray = new uint_fast64_t[new_length](); + + // 64 bit/entries per uint_fast64_t + uint_fast64_t copySize = (new_length <= (bucket_count * 64)) ? (new_length/64) : (bucket_count); + memcpy(tempArray, bucket_array, sizeof(uint_fast64_t) * copySize); + + bucket_count = new_length / 64; + if (new_length % 64 != 0) { + ++bucket_count; + } + + delete[] bucket_array; + bucket_array = tempArray; + } + + void set(const uint_fast64_t index, const bool value) { + uint_fast64_t bucket = index / 64; + // Taking the step with mask is crucial as we NEED a 64bit shift, not a 32bit one. + // MSVC: C4334, use 1i64 or cast to __int64. + // result of 32-bit shift implicitly converted to 64 bits (was 64-bit shift intended?) + uint_fast64_t mask = 1; + mask = mask << (index % 64); + if (value) { + bucket_array[bucket] |= mask; + } else { + bucket_array[bucket] &= ~mask; + } + } + + bool get(const uint_fast64_t index) { + uint_fast64_t bucket = index / 64; + // Taking the step with mask is crucial as we NEED a 64bit shift, not a 32bit one. + // MSVC: C4334, use 1i64 or cast to __int64. + // result of 32-bit shift implicitly converted to 64 bits (was 64-bit shift intended?) + uint_fast64_t mask = 1; + mask = mask << (index % 64); + return ((bucket_array[bucket] & mask) == mask); + } + + // Operators + BitVector operator &(BitVector const &bv) { + uint_fast64_t minSize = (bv.bucket_count < this->bucket_count) ? bv.bucket_count : this->bucket_count; + BitVector result(minSize * 64); + for (uint_fast64_t i = 0; i < minSize; ++i) { + result.bucket_array[i] = this->bucket_array[i] & bv.bucket_array[i]; + } + + return result; + } + BitVector operator |(BitVector const &bv) { + uint_fast64_t minSize = (bv.bucket_count < this->bucket_count) ? bv.bucket_count : this->bucket_count; + BitVector result(minSize * 64); + for (uint_fast64_t i = 0; i < minSize; ++i) { + result.bucket_array[i] = this->bucket_array[i] | bv.bucket_array[i]; + } + + return result; + } + + BitVector operator ^(BitVector const &bv) { + uint_fast64_t minSize = (bv.bucket_count < this->bucket_count) ? bv.bucket_count : this->bucket_count; + BitVector result(minSize * 64); + for (uint_fast64_t i = 0; i < minSize; ++i) { + result.bucket_array[i] = this->bucket_array[i] ^ bv.bucket_array[i]; + } + + return result; + } + + BitVector operator ~() { + BitVector result(this->bucket_count * 64); + for (uint_fast64_t i = 0; i < this->bucket_count; ++i) { + result.bucket_array[i] = ~this->bucket_array[i]; + } + + return result; + } + + /*! + * Returns the number of bits that are set (to one) in this bit vector. + * @return The number of bits that are set (to one) in this bit vector. + */ + uint_fast64_t getNumberOfSetBits() { + uint_fast64_t set_bits = 0; + for (uint_fast64_t i = 0; i < bucket_count; i++) { +#ifdef __GNUG__ // check if we are using g++ and use built-in function if available + set_bits += __builtin_popcount (this->bucket_array[i]); +#else + uint_fast32_t cnt; + uint_fast64_t bitset = this->bucket_array[i]; + for (cnt = 0; bitset; cnt++) { + bitset &= bitset - 1; + } + set_bits += cnt; +#endif + } + return set_bits; + } + + /*! + * Returns the size of the bit vector in memory measured in bytes. + * @return The size of the bit vector in memory measured in bytes. + */ + uint_fast64_t getSizeInMemory() { + return sizeof(*this) + sizeof(uint_fast64_t) * bucket_count; + } + + private: + uint_fast64_t bucket_count; + + /*! Array containing the boolean bits for each node, 64bits/64nodes per element */ + uint_fast64_t* bucket_array; + +}; + +} // namespace vector + +} // namespace mrmc + +#endif // MRMC_SPARSE_STATIC_SPARSE_MATRIX_H_ diff --git a/test/parser/ParseMdpTest.cpp b/test/parser/ParseMdpTest.cpp new file mode 100644 index 000000000..4147c0122 --- /dev/null +++ b/test/parser/ParseMdpTest.cpp @@ -0,0 +1,32 @@ +/* + * ParseMdpTest.cpp + * + * Created on: 14.01.2013 + * Author: Thomas Heinemann + */ + + +#include "gtest/gtest.h" +#include "storm-config.h" +#include "src/parser/MdpParser.h" +#include "src/utility/IoUtility.h" + +TEST(ParseMdpTest, parseAndOutput) { + storm::parser::MdpParser* mdpParser; + ASSERT_NO_THROW(mdpParser = new storm::parser::MdpParser( + STORM_CPP_TESTS_BASE_PATH "/parser/tra_files/mdp_general_input_01.tra", + STORM_CPP_TESTS_BASE_PATH "/parser/lab_files/pctl_general_input_01.lab")); + + std::shared_ptr> mdp = mdpParser->getMdp(); + std::shared_ptr> matrix = mdp->getTransitionProbabilityMatrix(); + + ASSERT_EQ(mdp->getNumberOfStates(), 3); + ASSERT_EQ(mdp->getNumberOfTransitions(), 11); + ASSERT_EQ(matrix->getRowCount(), 2 * 3); + ASSERT_EQ(matrix->getColumnCount(), 3); + + + delete mdpParser; +} + + diff --git a/test/parser/ReadTraFileTest.cpp b/test/parser/ReadTraFileTest.cpp index 77377d857..2bda7561e 100644 --- a/test/parser/ReadTraFileTest.cpp +++ b/test/parser/ReadTraFileTest.cpp @@ -7,7 +7,7 @@ #include "gtest/gtest.h" #include "storm-config.h" -#include "src/storage/SquareSparseMatrix.h" +#include "src/storage/SparseMatrix.h" #include "src/parser/DeterministicSparseTransitionParser.h" #include "src/exceptions/FileIoException.h" #include "src/exceptions/WrongFileFormatException.h" @@ -24,7 +24,7 @@ TEST(ReadTraFileTest, NonExistingFileTest) { TEST(ReadTraFileTest, ParseFileTest1) { storm::parser::DeterministicSparseTransitionParser* parser; ASSERT_NO_THROW(parser = new storm::parser::DeterministicSparseTransitionParser(STORM_CPP_TESTS_BASE_PATH "/parser/tra_files/csl_general_input_01.tra")); - std::shared_ptr> result = parser->getMatrix(); + std::shared_ptr> result = parser->getMatrix(); if (result != NULL) { double val = 0; @@ -53,13 +53,13 @@ TEST(ReadTraFileTest, ParseFileTest1) { ASSERT_TRUE(result->getValue(3,2,&val)); ASSERT_EQ(val,0.0806451612903225806451612903225812); - ASSERT_TRUE(result->getValue(3,3,&val)); + ASSERT_FALSE(result->getValue(3,3,&val)); ASSERT_EQ(val,0); ASSERT_TRUE(result->getValue(3,4,&val)); ASSERT_EQ(val,0.080645161290322580645161290322581); - ASSERT_TRUE(result->getValue(4,4,&val)); + ASSERT_FALSE(result->getValue(4,4,&val)); ASSERT_EQ(val,0); delete parser; diff --git a/test/storage/SquareSparseMatrixTest.cpp b/test/storage/SparseMatrixTest.cpp similarity index 58% rename from test/storage/SquareSparseMatrixTest.cpp rename to test/storage/SparseMatrixTest.cpp index 61d64c933..2a29fb736 100644 --- a/test/storage/SquareSparseMatrixTest.cpp +++ b/test/storage/SparseMatrixTest.cpp @@ -1,73 +1,73 @@ #include "gtest/gtest.h" -#include "src/storage/SquareSparseMatrix.h" +#include "src/storage/SparseMatrix.h" #include "src/exceptions/InvalidArgumentException.h" #include "src/exceptions/OutOfRangeException.h" -TEST(SquareSparseMatrixTest, ZeroRowsTest) { - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(0); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::UnInitialized); +TEST(SparseMatrixTest, ZeroRowsTest) { + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(0); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::UnInitialized); ASSERT_THROW(ssm->initialize(50), storm::exceptions::InvalidArgumentException); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Error); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Error); delete ssm; } -TEST(SquareSparseMatrixTest, TooManyEntriesTest) { - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(2); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::UnInitialized); +TEST(SparseMatrixTest, TooManyEntriesTest) { + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(2); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::UnInitialized); ASSERT_THROW(ssm->initialize(10), storm::exceptions::InvalidArgumentException); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Error); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Error); delete ssm; } -TEST(SquareSparseMatrixTest, addNextValueTest) { - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(5); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::UnInitialized); +TEST(SparseMatrixTest, addNextValueTest) { + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(5); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::UnInitialized); ASSERT_NO_THROW(ssm->initialize(1)); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Initialized); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Initialized); ASSERT_THROW(ssm->addNextValue(-1, 1, 1), storm::exceptions::OutOfRangeException); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Error); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Error); ASSERT_THROW(ssm->addNextValue(1, -1, 1), storm::exceptions::OutOfRangeException); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Error); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Error); ASSERT_THROW(ssm->addNextValue(6, 1, 1), storm::exceptions::OutOfRangeException); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Error); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Error); ASSERT_THROW(ssm->addNextValue(1, 6, 1), storm::exceptions::OutOfRangeException); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Error); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Error); delete ssm; } -TEST(SquareSparseMatrixTest, finalizeTest) { - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(5); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::UnInitialized); +TEST(SparseMatrixTest, finalizeTest) { + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(5); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::UnInitialized); ASSERT_NO_THROW(ssm->initialize(5)); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Initialized); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Initialized); ASSERT_NO_THROW(ssm->addNextValue(1, 2, 1)); ASSERT_NO_THROW(ssm->addNextValue(1, 3, 1)); ASSERT_NO_THROW(ssm->addNextValue(1, 4, 1)); ASSERT_NO_THROW(ssm->addNextValue(1, 5, 1)); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Initialized); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Initialized); ASSERT_THROW(ssm->finalize(), storm::exceptions::InvalidStateException); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Error); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Error); delete ssm; } -TEST(SquareSparseMatrixTest, Test) { +TEST(SparseMatrixTest, Test) { // 25 rows, 50 non zero entries - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(25); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::UnInitialized); + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(25); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::UnInitialized); int values[50] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, @@ -96,15 +96,15 @@ TEST(SquareSparseMatrixTest, Test) { }; ASSERT_NO_THROW(ssm->initialize(50)); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Initialized); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Initialized); for (int i = 0; i < 50; ++i) { ASSERT_NO_THROW(ssm->addNextValue(position_row[i], position_col[i], values[i])); } - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Initialized); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Initialized); ASSERT_NO_THROW(ssm->finalize()); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::ReadReady); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::ReadReady); int target; for (int i = 0; i < 50; ++i) { @@ -116,24 +116,20 @@ TEST(SquareSparseMatrixTest, Test) { for (int row = 15; row < 24; ++row) { for (int col = 1; col <= 25; ++col) { target = 1; - if (row != col) { - ASSERT_FALSE(ssm->getValue(row, col, &target)); - } else { - ASSERT_TRUE(ssm->getValue(row, col, &target)); - } + ASSERT_FALSE(ssm->getValue(row, col, &target)); ASSERT_EQ(0, target); } } - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::ReadReady); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::ReadReady); delete ssm; } -TEST(SquareSparseMatrixTest, ConversionFromDenseEigen_ColMajor_SparseMatrixTest) { +TEST(SparseMatrixTest, ConversionFromDenseEigen_ColMajor_SparseMatrixTest) { // 10 rows, 100 non zero entries - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(10); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::UnInitialized); + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(10); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::UnInitialized); Eigen::SparseMatrix esm(10, 10); for (int row = 0; row < 10; ++row) { @@ -149,7 +145,7 @@ TEST(SquareSparseMatrixTest, ConversionFromDenseEigen_ColMajor_SparseMatrixTest) ASSERT_NO_THROW(ssm->finalize()); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::ReadReady); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::ReadReady); int target = -1; for (int row = 0; row < 10; ++row) { @@ -162,10 +158,10 @@ TEST(SquareSparseMatrixTest, ConversionFromDenseEigen_ColMajor_SparseMatrixTest) delete ssm; } -TEST(SquareSparseMatrixTest, ConversionFromDenseEigen_RowMajor_SparseMatrixTest) { +TEST(SparseMatrixTest, ConversionFromDenseEigen_RowMajor_SparseMatrixTest) { // 10 rows, 100 non zero entries - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(10); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::UnInitialized); + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(10); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::UnInitialized); Eigen::SparseMatrix esm(10, 10); for (int row = 0; row < 10; ++row) { @@ -181,7 +177,7 @@ TEST(SquareSparseMatrixTest, ConversionFromDenseEigen_RowMajor_SparseMatrixTest) ASSERT_NO_THROW(ssm->finalize()); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::ReadReady); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::ReadReady); int target = -1; for (int row = 0; row < 10; ++row) { @@ -194,10 +190,10 @@ TEST(SquareSparseMatrixTest, ConversionFromDenseEigen_RowMajor_SparseMatrixTest) delete ssm; } -TEST(SquareSparseMatrixTest, ConversionFromSparseEigen_ColMajor_SparseMatrixTest) { +TEST(SparseMatrixTest, ConversionFromSparseEigen_ColMajor_SparseMatrixTest) { // 10 rows, 15 non zero entries - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(10); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::UnInitialized); + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(10); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::UnInitialized); Eigen::SparseMatrix esm(10, 10); @@ -231,7 +227,7 @@ TEST(SquareSparseMatrixTest, ConversionFromSparseEigen_ColMajor_SparseMatrixTest ASSERT_NO_THROW(ssm->finalize()); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::ReadReady); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::ReadReady); int target = -1; @@ -243,17 +239,17 @@ TEST(SquareSparseMatrixTest, ConversionFromSparseEigen_ColMajor_SparseMatrixTest delete ssm; } -TEST(SquareSparseMatrixTest, ConversionFromSparseEigen_RowMajor_SparseMatrixTest) { +TEST(SparseMatrixTest, ConversionFromSparseEigen_RowMajor_SparseMatrixTest) { // 10 rows, 15 non zero entries - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(10); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::UnInitialized); + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(10, 10); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::UnInitialized); Eigen::SparseMatrix esm(10, 10); typedef Eigen::Triplet IntTriplet; std::vector tripletList; tripletList.reserve(15); - tripletList.push_back(IntTriplet(1, 0, 0)); + tripletList.push_back(IntTriplet(1, 0, 15)); tripletList.push_back(IntTriplet(1, 1, 1)); tripletList.push_back(IntTriplet(1, 2, 2)); tripletList.push_back(IntTriplet(1, 3, 3)); @@ -280,38 +276,42 @@ TEST(SquareSparseMatrixTest, ConversionFromSparseEigen_RowMajor_SparseMatrixTest ASSERT_NO_THROW(ssm->finalize()); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::ReadReady); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::ReadReady); + + const std::vector rowP = ssm->getRowIndicationsPointer(); + const std::vector colP = ssm->getColumnIndicationsPointer(); + const std::vector valP = ssm->getStoragePointer(); int target = -1; - for (auto &coeff: tripletList) { - ASSERT_TRUE(ssm->getValue(coeff.row(), coeff.col(), &target)); + bool retVal = ssm->getValue(coeff.row(), coeff.col(), &target); + ASSERT_TRUE(retVal); ASSERT_EQ(target, coeff.value()); } delete ssm; } -TEST(SquareSparseMatrixTest, ConversionToSparseEigen_RowMajor_SparseMatrixTest) { +TEST(SparseMatrixTest, ConversionToSparseEigen_RowMajor_SparseMatrixTest) { int values[100]; - storm::storage::SquareSparseMatrix *ssm = new storm::storage::SquareSparseMatrix(10); + storm::storage::SparseMatrix *ssm = new storm::storage::SparseMatrix(10); for (uint_fast32_t i = 0; i < 100; ++i) { values[i] = i; } - ASSERT_NO_THROW(ssm->initialize(100 - 10)); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Initialized); + ASSERT_NO_THROW(ssm->initialize(100)); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Initialized); for (uint_fast32_t row = 0; row < 10; ++row) { for (uint_fast32_t col = 0; col < 10; ++col) { ASSERT_NO_THROW(ssm->addNextValue(row, col, values[row * 10 + col])); } } - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::Initialized); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::Initialized); ASSERT_NO_THROW(ssm->finalize()); - ASSERT_EQ(ssm->getState(), storm::storage::SquareSparseMatrix::MatrixStatus::ReadReady); + ASSERT_EQ(ssm->getState(), storm::storage::SparseMatrix::MatrixStatus::ReadReady); Eigen::SparseMatrix* esm = ssm->toEigenSparseMatrix(); diff --git a/test/mrmc-tests.cpp b/test/storm-tests.cpp similarity index 100% rename from test/mrmc-tests.cpp rename to test/storm-tests.cpp