1
0
mirror of https://github.com/janeczku/calibre-web synced 2025-11-19 00:15:11 +00:00
This commit is contained in:
Jan Broer
2015-08-02 20:59:11 +02:00
commit 64a9cbce2d
743 changed files with 233749 additions and 0 deletions

View File

@@ -0,0 +1,22 @@
from __future__ import absolute_import
from .warnings import testing_warn, assert_warnings, resetwarnings
from . import config
from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
fails_on, fails_on_everything_except, skip, only_on, exclude, against,\
_server_version, only_if
from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
eq_, ne_, is_, is_not_, startswith_, assert_raises, \
assert_raises_message, AssertsCompiledSQL, ComparesTables, \
AssertsExecutionResults
from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
crashes = skip
from .config import db, requirements as requires
from . import mock

View File

@@ -0,0 +1,377 @@
from __future__ import absolute_import
from . import util as testutil
from sqlalchemy import pool, orm, util
from sqlalchemy.engine import default, create_engine
from sqlalchemy import exc as sa_exc
from sqlalchemy.util import decorator
from sqlalchemy import types as sqltypes, schema
import warnings
import re
from .warnings import resetwarnings
from .exclusions import db_spec, _is_excluded
from . import assertsql
from . import config
import itertools
from .util import fail
import contextlib
def emits_warning(*messages):
"""Mark a test as emitting a warning.
With no arguments, squelches all SAWarning failures. Or pass one or more
strings; these will be matched to the root of the warning description by
warnings.filterwarnings().
"""
# TODO: it would be nice to assert that a named warning was
# emitted. should work with some monkeypatching of warnings,
# and may work on non-CPython if they keep to the spirit of
# warnings.showwarning's docstring.
# - update: jython looks ok, it uses cpython's module
@decorator
def decorate(fn, *args, **kw):
# todo: should probably be strict about this, too
filters = [dict(action='ignore',
category=sa_exc.SAPendingDeprecationWarning)]
if not messages:
filters.append(dict(action='ignore',
category=sa_exc.SAWarning))
else:
filters.extend(dict(action='ignore',
message=message,
category=sa_exc.SAWarning)
for message in messages)
for f in filters:
warnings.filterwarnings(**f)
try:
return fn(*args, **kw)
finally:
resetwarnings()
return decorate
def emits_warning_on(db, *warnings):
"""Mark a test as emitting a warning on a specific dialect.
With no arguments, squelches all SAWarning failures. Or pass one or more
strings; these will be matched to the root of the warning description by
warnings.filterwarnings().
"""
spec = db_spec(db)
@decorator
def decorate(fn, *args, **kw):
if isinstance(db, basestring):
if not spec(config.db):
return fn(*args, **kw)
else:
wrapped = emits_warning(*warnings)(fn)
return wrapped(*args, **kw)
else:
if not _is_excluded(*db):
return fn(*args, **kw)
else:
wrapped = emits_warning(*warnings)(fn)
return wrapped(*args, **kw)
return decorate
def uses_deprecated(*messages):
"""Mark a test as immune from fatal deprecation warnings.
With no arguments, squelches all SADeprecationWarning failures.
Or pass one or more strings; these will be matched to the root
of the warning description by warnings.filterwarnings().
As a special case, you may pass a function name prefixed with //
and it will be re-written as needed to match the standard warning
verbiage emitted by the sqlalchemy.util.deprecated decorator.
"""
@decorator
def decorate(fn, *args, **kw):
# todo: should probably be strict about this, too
filters = [dict(action='ignore',
category=sa_exc.SAPendingDeprecationWarning)]
if not messages:
filters.append(dict(action='ignore',
category=sa_exc.SADeprecationWarning))
else:
filters.extend(
[dict(action='ignore',
message=message,
category=sa_exc.SADeprecationWarning)
for message in
[(m.startswith('//') and
('Call to deprecated function ' + m[2:]) or m)
for m in messages]])
for f in filters:
warnings.filterwarnings(**f)
try:
return fn(*args, **kw)
finally:
resetwarnings()
return decorate
def global_cleanup_assertions():
"""Check things that have to be finalized at the end of a test suite.
Hardcoded at the moment, a modular system can be built here
to support things like PG prepared transactions, tables all
dropped, etc.
"""
testutil.lazy_gc()
assert not pool._refs, str(pool._refs)
def eq_(a, b, msg=None):
"""Assert a == b, with repr messaging on failure."""
assert a == b, msg or "%r != %r" % (a, b)
def ne_(a, b, msg=None):
"""Assert a != b, with repr messaging on failure."""
assert a != b, msg or "%r == %r" % (a, b)
def is_(a, b, msg=None):
"""Assert a is b, with repr messaging on failure."""
assert a is b, msg or "%r is not %r" % (a, b)
def is_not_(a, b, msg=None):
"""Assert a is not b, with repr messaging on failure."""
assert a is not b, msg or "%r is %r" % (a, b)
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
a, fragment)
def assert_raises(except_cls, callable_, *args, **kw):
try:
callable_(*args, **kw)
success = False
except except_cls:
success = True
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
try:
callable_(*args, **kwargs)
assert False, "Callable did not raise an exception"
except except_cls, e:
assert re.search(msg, unicode(e), re.UNICODE), u"%r !~ %s" % (msg, e)
print unicode(e).encode('utf-8')
class AssertsCompiledSQL(object):
def assert_compile(self, clause, result, params=None,
checkparams=None, dialect=None,
checkpositional=None,
use_default_dialect=False,
allow_dialect_select=False):
if use_default_dialect:
dialect = default.DefaultDialect()
elif dialect == None and not allow_dialect_select:
dialect = getattr(self, '__dialect__', None)
if dialect == 'default':
dialect = default.DefaultDialect()
elif dialect is None:
dialect = config.db.dialect
elif isinstance(dialect, basestring):
dialect = create_engine("%s://" % dialect).dialect
kw = {}
if params is not None:
kw['column_keys'] = params.keys()
if isinstance(clause, orm.Query):
context = clause._compile_context()
context.statement.use_labels = True
clause = context.statement
c = clause.compile(dialect=dialect, **kw)
param_str = repr(getattr(c, 'params', {}))
# Py3K
#param_str = param_str.encode('utf-8').decode('ascii', 'ignore')
print "\nSQL String:\n" + str(c) + param_str
cc = re.sub(r'[\n\t]', '', str(c))
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
if checkparams is not None:
eq_(c.construct_params(params), checkparams)
if checkpositional is not None:
p = c.construct_params(params)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
class ComparesTables(object):
def assert_tables_equal(self, table, reflected_table, strict_types=False):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
eq_(c.name, reflected_c.name)
assert reflected_c is reflected_table.c[c.name]
eq_(c.primary_key, reflected_c.primary_key)
eq_(c.nullable, reflected_c.nullable)
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
assert type(reflected_c.type) is type(c.type), \
msg % (reflected_c.type, c.type)
else:
self.assert_types_base(reflected_c, c)
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
eq_(
set([f.column.name for f in c.foreign_keys]),
set([f.column.name for f in reflected_c.foreign_keys])
)
if c.server_default:
assert isinstance(reflected_c.server_default,
schema.FetchedValue)
assert len(table.primary_key) == len(reflected_table.primary_key)
for c in table.primary_key:
assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
assert c1.type._compare_type_affinity(c2.type),\
"On column %r, type '%s' doesn't correspond to type '%s'" % \
(c1.name, c1.type, c2.type)
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
result = list(result)
print repr(result)
self.assert_list(result, class_, objects)
def assert_list(self, result, class_, list):
self.assert_(len(result) == len(list),
"result list is not the same size as test list, " +
"for class " + class_.__name__)
for i in range(0, len(list)):
self.assert_row(class_, result[i], list[i])
def assert_row(self, class_, rowobj, desc):
self.assert_(rowobj.__class__ is class_,
"item class is not " + repr(class_))
for key, value in desc.iteritems():
if isinstance(value, tuple):
if isinstance(value[1], list):
self.assert_list(getattr(rowobj, key), value[0], value[1])
else:
self.assert_row(value[0], getattr(rowobj, key), value[1])
else:
self.assert_(getattr(rowobj, key) == value,
"attribute %s value %s does not match %s" % (
key, getattr(rowobj, key), value))
def assert_unordered_result(self, result, cls, *expected):
"""As assert_result, but the order of objects is not considered.
The algorithm is very expensive but not a big deal for the small
numbers of rows that the test suite manipulates.
"""
class immutabledict(dict):
def __hash__(self):
return id(self)
found = util.IdentitySet(result)
expected = set([immutabledict(e) for e in expected])
for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
fail('Unexpected type "%s", expected "%s"' % (
type(wrong).__name__, cls.__name__))
if len(found) != len(expected):
fail('Unexpected object count "%s", expected "%s"' % (
len(found), len(expected)))
NOVALUE = object()
def _compare_item(obj, spec):
for key, value in spec.iteritems():
if isinstance(value, tuple):
try:
self.assert_unordered_result(
getattr(obj, key), value[0], *value[1])
except AssertionError:
return False
else:
if getattr(obj, key, NOVALUE) != value:
return False
return True
for expected_item in expected:
for found_item in found:
if _compare_item(found_item, expected_item):
found.remove(found_item)
break
else:
fail(
"Expected %s instance with attributes %s not found." % (
cls.__name__, repr(expected_item)))
return True
def assert_sql_execution(self, db, callable_, *rules):
assertsql.asserter.add_rules(rules)
try:
callable_()
assertsql.asserter.statement_complete()
finally:
assertsql.asserter.clear_rules()
def assert_sql(self, db, callable_, list_, with_sequences=None):
if with_sequences is not None and config.db.dialect.supports_sequences:
rules = with_sequences
else:
rules = list_
newrules = []
for rule in rules:
if isinstance(rule, dict):
newrule = assertsql.AllOf(*[
assertsql.ExactSQL(k, v) for k, v in rule.iteritems()
])
else:
newrule = assertsql.ExactSQL(*rule)
newrules.append(newrule)
self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
self.assert_sql_execution(
db, callable_, assertsql.CountStatements(count))
@contextlib.contextmanager
def assert_execution(self, *rules):
assertsql.asserter.add_rules(rules)
try:
yield
assertsql.asserter.statement_complete()
finally:
assertsql.asserter.clear_rules()
def assert_statement_count(self, count):
return self.assert_execution(assertsql.CountStatements(count))

View File

@@ -0,0 +1,328 @@
from ..engine.default import DefaultDialect
from .. import util
import re
class AssertRule(object):
def process_execute(self, clauseelement, *multiparams, **params):
pass
def process_cursor_execute(self, statement, parameters, context,
executemany):
pass
def is_consumed(self):
"""Return True if this rule has been consumed, False if not.
Should raise an AssertionError if this rule's condition has
definitely failed.
"""
raise NotImplementedError()
def rule_passed(self):
"""Return True if the last test of this rule passed, False if
failed, None if no test was applied."""
raise NotImplementedError()
def consume_final(self):
"""Return True if this rule has been consumed.
Should raise an AssertionError if this rule's condition has not
been consumed or has failed.
"""
if self._result is None:
assert False, 'Rule has not been consumed'
return self.is_consumed()
class SQLMatchRule(AssertRule):
def __init__(self):
self._result = None
self._errmsg = ""
def rule_passed(self):
return self._result
def is_consumed(self):
if self._result is None:
return False
assert self._result, self._errmsg
return True
class ExactSQL(SQLMatchRule):
def __init__(self, sql, params=None):
SQLMatchRule.__init__(self)
self.sql = sql
self.params = params
def process_cursor_execute(self, statement, parameters, context,
executemany):
if not context:
return
_received_statement = \
_process_engine_statement(context.unicode_statement,
context)
_received_parameters = context.compiled_parameters
# TODO: remove this step once all unit tests are migrated, as
# ExactSQL should really be *exact* SQL
sql = _process_assertion_statement(self.sql, context)
equivalent = _received_statement == sql
if self.params:
if util.callable(self.params):
params = self.params(context)
else:
params = self.params
if not isinstance(params, list):
params = [params]
equivalent = equivalent and params \
== context.compiled_parameters
else:
params = {}
self._result = equivalent
if not self._result:
self._errmsg = \
'Testing for exact statement %r exact params %r, '\
'received %r with params %r' % (sql, params,
_received_statement, _received_parameters)
class RegexSQL(SQLMatchRule):
def __init__(self, regex, params=None):
SQLMatchRule.__init__(self)
self.regex = re.compile(regex)
self.orig_regex = regex
self.params = params
def process_cursor_execute(self, statement, parameters, context,
executemany):
if not context:
return
_received_statement = \
_process_engine_statement(context.unicode_statement,
context)
_received_parameters = context.compiled_parameters
equivalent = bool(self.regex.match(_received_statement))
if self.params:
if util.callable(self.params):
params = self.params(context)
else:
params = self.params
if not isinstance(params, list):
params = [params]
# do a positive compare only
for param, received in zip(params, _received_parameters):
for k, v in param.iteritems():
if k not in received or received[k] != v:
equivalent = False
break
else:
params = {}
self._result = equivalent
if not self._result:
self._errmsg = \
'Testing for regex %r partial params %r, received %r '\
'with params %r' % (self.orig_regex, params,
_received_statement,
_received_parameters)
class CompiledSQL(SQLMatchRule):
def __init__(self, statement, params=None):
SQLMatchRule.__init__(self)
self.statement = statement
self.params = params
def process_cursor_execute(self, statement, parameters, context,
executemany):
if not context:
return
from sqlalchemy.schema import _DDLCompiles
_received_parameters = list(context.compiled_parameters)
# recompile from the context, using the default dialect
if isinstance(context.compiled.statement, _DDLCompiles):
compiled = \
context.compiled.statement.compile(dialect=DefaultDialect())
else:
compiled = \
context.compiled.statement.compile(dialect=DefaultDialect(),
column_keys=context.compiled.column_keys)
_received_statement = re.sub(r'[\n\t]', '', str(compiled))
equivalent = self.statement == _received_statement
if self.params:
if util.callable(self.params):
params = self.params(context)
else:
params = self.params
if not isinstance(params, list):
params = [params]
else:
params = list(params)
all_params = list(params)
all_received = list(_received_parameters)
while params:
param = dict(params.pop(0))
for k, v in context.compiled.params.iteritems():
param.setdefault(k, v)
if param not in _received_parameters:
equivalent = False
break
else:
_received_parameters.remove(param)
if _received_parameters:
equivalent = False
else:
params = {}
all_params = {}
all_received = []
self._result = equivalent
if not self._result:
print 'Testing for compiled statement %r partial params '\
'%r, received %r with params %r' % (self.statement,
all_params, _received_statement, all_received)
self._errmsg = \
'Testing for compiled statement %r partial params %r, '\
'received %r with params %r' % (self.statement,
all_params, _received_statement, all_received)
# print self._errmsg
class CountStatements(AssertRule):
def __init__(self, count):
self.count = count
self._statement_count = 0
def process_execute(self, clauseelement, *multiparams, **params):
self._statement_count += 1
def process_cursor_execute(self, statement, parameters, context,
executemany):
pass
def is_consumed(self):
return False
def consume_final(self):
assert self.count == self._statement_count, \
'desired statement count %d does not match %d' \
% (self.count, self._statement_count)
return True
class AllOf(AssertRule):
def __init__(self, *rules):
self.rules = set(rules)
def process_execute(self, clauseelement, *multiparams, **params):
for rule in self.rules:
rule.process_execute(clauseelement, *multiparams, **params)
def process_cursor_execute(self, statement, parameters, context,
executemany):
for rule in self.rules:
rule.process_cursor_execute(statement, parameters, context,
executemany)
def is_consumed(self):
if not self.rules:
return True
for rule in list(self.rules):
if rule.rule_passed(): # a rule passed, move on
self.rules.remove(rule)
return len(self.rules) == 0
assert False, 'No assertion rules were satisfied for statement'
def consume_final(self):
return len(self.rules) == 0
def _process_engine_statement(query, context):
if util.jython:
# oracle+zxjdbc passes a PyStatement when returning into
query = unicode(query)
if context.engine.name == 'mssql' \
and query.endswith('; select scope_identity()'):
query = query[:-25]
query = re.sub(r'\n', '', query)
return query
def _process_assertion_statement(query, context):
paramstyle = context.dialect.paramstyle
if paramstyle == 'named':
pass
elif paramstyle == 'pyformat':
query = re.sub(r':([\w_]+)', r"%(\1)s", query)
else:
# positional params
repl = None
if paramstyle == 'qmark':
repl = "?"
elif paramstyle == 'format':
repl = r"%s"
elif paramstyle == 'numeric':
repl = None
query = re.sub(r':([\w_]+)', repl, query)
return query
class SQLAssert(object):
rules = None
def add_rules(self, rules):
self.rules = list(rules)
def statement_complete(self):
for rule in self.rules:
if not rule.consume_final():
assert False, \
'All statements are complete, but pending '\
'assertion rules remain'
def clear_rules(self):
del self.rules
def execute(self, conn, clauseelement, multiparams, params, result):
if self.rules is not None:
if not self.rules:
assert False, \
'All rules have been exhausted, but further '\
'statements remain'
rule = self.rules[0]
rule.process_execute(clauseelement, *multiparams, **params)
if rule.is_consumed():
self.rules.pop(0)
def cursor_execute(self, conn, cursor, statement, parameters,
context, executemany):
if self.rules:
rule = self.rules[0]
rule.process_cursor_execute(statement, parameters, context,
executemany)
asserter = SQLAssert()

View File

@@ -0,0 +1,2 @@
requirements = None
db = None

View File

@@ -0,0 +1,456 @@
from __future__ import absolute_import
import types
import weakref
from collections import deque
from . import config
from .util import decorator
from .. import event, pool
import re
import warnings
class ConnectionKiller(object):
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
self.testing_engines = weakref.WeakKeyDictionary()
self.conns = set()
def add_engine(self, engine):
self.testing_engines[engine] = True
def connect(self, dbapi_conn, con_record):
self.conns.add((dbapi_conn, con_record))
def checkout(self, dbapi_con, con_record, con_proxy):
self.proxy_refs[con_proxy] = True
def _safe(self, fn):
try:
fn()
except (SystemExit, KeyboardInterrupt):
raise
except Exception, e:
warnings.warn(
"testing_reaper couldn't "
"rollback/close connection: %s" % e)
def rollback_all(self):
for rec in self.proxy_refs.keys():
if rec is not None and rec.is_valid:
self._safe(rec.rollback)
def close_all(self):
for rec in self.proxy_refs.keys():
if rec is not None:
self._safe(rec._close)
def _after_test_ctx(self):
pass
# this can cause a deadlock with pg8000 - pg8000 acquires
# prepared statment lock inside of rollback() - if async gc
# is collecting in finalize_fairy, deadlock.
# not sure if this should be if pypy/jython only.
# note that firebird/fdb definitely needs this though
for conn, rec in self.conns:
self._safe(conn.rollback)
def _stop_test_ctx(self):
if config.options.low_connections:
self._stop_test_ctx_minimal()
else:
self._stop_test_ctx_aggressive()
def _stop_test_ctx_minimal(self):
self.close_all()
self.conns = set()
for rec in self.testing_engines.keys():
if rec is not config.db:
rec.dispose()
def _stop_test_ctx_aggressive(self):
self.close_all()
for conn, rec in self.conns:
self._safe(conn.close)
rec.connection = None
self.conns = set()
for rec in self.testing_engines.keys():
rec.dispose()
def assert_all_closed(self):
for rec in self.proxy_refs:
if rec.is_valid:
assert False
testing_reaper = ConnectionKiller()
def drop_all_tables(metadata, bind):
testing_reaper.close_all()
if hasattr(bind, 'close'):
bind.close()
metadata.drop_all(bind)
@decorator
def assert_conns_closed(fn, *args, **kw):
try:
fn(*args, **kw)
finally:
testing_reaper.assert_all_closed()
@decorator
def rollback_open_connections(fn, *args, **kw):
"""Decorator that rolls back all open connections after fn execution."""
try:
fn(*args, **kw)
finally:
testing_reaper.rollback_all()
@decorator
def close_first(fn, *args, **kw):
"""Decorator that closes all connections before fn execution."""
testing_reaper.close_all()
fn(*args, **kw)
@decorator
def close_open_connections(fn, *args, **kw):
"""Decorator that closes all connections after fn execution."""
try:
fn(*args, **kw)
finally:
testing_reaper.close_all()
def all_dialects(exclude=None):
import sqlalchemy.databases as d
for name in d.__all__:
# TEMPORARY
if exclude and name in exclude:
continue
mod = getattr(d, name, None)
if not mod:
mod = getattr(__import__(
'sqlalchemy.databases.%s' % name).databases, name)
yield mod.dialect()
class ReconnectFixture(object):
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
def __getattr__(self, key):
return getattr(self.dbapi, key)
def connect(self, *args, **kwargs):
conn = self.dbapi.connect(*args, **kwargs)
self.connections.append(conn)
return conn
def _safe(self, fn):
try:
fn()
except (SystemExit, KeyboardInterrupt):
raise
except Exception, e:
warnings.warn(
"ReconnectFixture couldn't "
"close connection: %s" % e)
def shutdown(self):
# TODO: this doesn't cover all cases
# as nicely as we'd like, namely MySQLdb.
# would need to implement R. Brewer's
# proxy server idea to get better
# coverage.
for c in list(self.connections):
self._safe(c.close)
self.connections = []
def reconnecting_engine(url=None, options=None):
url = url or config.db_url
dbapi = config.db.dialect.dbapi
if not options:
options = {}
options['module'] = ReconnectFixture(dbapi)
engine = testing_engine(url, options)
_dispose = engine.dispose
def dispose():
engine.dialect.dbapi.shutdown()
_dispose()
engine.test_shutdown = engine.dialect.dbapi.shutdown
engine.dispose = dispose
return engine
def testing_engine(url=None, options=None):
"""Produce an engine configured by --options with optional overrides."""
from sqlalchemy import create_engine
from .assertsql import asserter
if not options:
use_reaper = True
else:
use_reaper = options.pop('use_reaper', True)
url = url or config.db_url
if options is None:
options = config.db_opts
engine = create_engine(url, **options)
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
engine.pool._max_overflow = 0
event.listen(engine, 'after_execute', asserter.execute)
event.listen(engine, 'after_cursor_execute', asserter.cursor_execute)
if use_reaper:
event.listen(engine.pool, 'connect', testing_reaper.connect)
event.listen(engine.pool, 'checkout', testing_reaper.checkout)
testing_reaper.add_engine(engine)
return engine
def utf8_engine(url=None, options=None):
"""Hook for dialects or drivers that don't handle utf8 by default."""
from sqlalchemy.engine import url as engine_url
if config.db.dialect.name == 'mysql' and \
config.db.driver in ['mysqldb', 'pymysql', 'cymysql']:
# note 1.2.1.gamma.6 or greater of MySQLdb
# needed here
url = url or config.db_url
url = engine_url.make_url(url)
url.query['charset'] = 'utf8'
url.query['use_unicode'] = '0'
url = str(url)
return testing_engine(url, options)
def mock_engine(dialect_name=None):
"""Provides a mocking engine based on the current testing.db.
This is normally used to test DDL generation flow as emitted
by an Engine.
It should not be used in other cases, as assert_compile() and
assert_sql_execution() are much better choices with fewer
moving parts.
"""
from sqlalchemy import create_engine
if not dialect_name:
dialect_name = config.db.name
buffer = []
def executor(sql, *a, **kw):
buffer.append(sql)
def assert_sql(stmts):
recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
assert recv == stmts, recv
def print_sql():
d = engine.dialect
return "\n".join(
str(s.compile(dialect=d))
for s in engine.mock
)
engine = create_engine(dialect_name + '://',
strategy='mock', executor=executor)
assert not hasattr(engine, 'mock')
engine.mock = buffer
engine.assert_sql = assert_sql
engine.print_sql = print_sql
return engine
class DBAPIProxyCursor(object):
"""Proxy a DBAPI cursor.
Tests can provide subclasses of this to intercept
DBAPI-level cursor operations.
"""
def __init__(self, engine, conn):
self.engine = engine
self.connection = conn
self.cursor = conn.cursor()
def execute(self, stmt, parameters=None, **kw):
if parameters:
return self.cursor.execute(stmt, parameters, **kw)
else:
return self.cursor.execute(stmt, **kw)
def executemany(self, stmt, params, **kw):
return self.cursor.executemany(stmt, params, **kw)
def __getattr__(self, key):
return getattr(self.cursor, key)
class DBAPIProxyConnection(object):
"""Proxy a DBAPI connection.
Tests can provide subclasses of this to intercept
DBAPI-level connection operations.
"""
def __init__(self, engine, cursor_cls):
self.conn = self._sqla_unwrap = engine.pool._creator()
self.engine = engine
self.cursor_cls = cursor_cls
def cursor(self):
return self.cursor_cls(self.engine, self.conn)
def close(self):
self.conn.close()
def __getattr__(self, key):
return getattr(self.conn, key)
def proxying_engine(conn_cls=DBAPIProxyConnection,
cursor_cls=DBAPIProxyCursor):
"""Produce an engine that provides proxy hooks for
common methods.
"""
def mock_conn():
return conn_cls(config.db, cursor_cls)
return testing_engine(options={'creator': mock_conn})
class ReplayableSession(object):
"""A simple record/playback tool.
This is *not* a mock testing class. It only records a session for later
playback and makes no assertions on call consistency whatsoever. It's
unlikely to be suitable for anything other than DB-API recording.
"""
Callable = object()
NoAttribute = object()
# Py3K
#Natives = set([getattr(types, t)
# for t in dir(types) if not t.startswith('_')]). \
# union([type(t) if not isinstance(t, type)
# else t for t in __builtins__.values()]).\
# difference([getattr(types, t)
# for t in ('FunctionType', 'BuiltinFunctionType',
# 'MethodType', 'BuiltinMethodType',
# 'LambdaType', )])
# Py2K
Natives = set([getattr(types, t)
for t in dir(types) if not t.startswith('_')]). \
difference([getattr(types, t)
for t in ('FunctionType', 'BuiltinFunctionType',
'MethodType', 'BuiltinMethodType',
'LambdaType', 'UnboundMethodType',)])
# end Py2K
def __init__(self):
self.buffer = deque()
def recorder(self, base):
return self.Recorder(self.buffer, base)
def player(self):
return self.Player(self.buffer)
class Recorder(object):
def __init__(self, buffer, subject):
self._buffer = buffer
self._subject = subject
def __call__(self, *args, **kw):
subject, buffer = [object.__getattribute__(self, x)
for x in ('_subject', '_buffer')]
result = subject(*args, **kw)
if type(result) not in ReplayableSession.Natives:
buffer.append(ReplayableSession.Callable)
return type(self)(buffer, result)
else:
buffer.append(result)
return result
@property
def _sqla_unwrap(self):
return self._subject
def __getattribute__(self, key):
try:
return object.__getattribute__(self, key)
except AttributeError:
pass
subject, buffer = [object.__getattribute__(self, x)
for x in ('_subject', '_buffer')]
try:
result = type(subject).__getattribute__(subject, key)
except AttributeError:
buffer.append(ReplayableSession.NoAttribute)
raise
else:
if type(result) not in ReplayableSession.Natives:
buffer.append(ReplayableSession.Callable)
return type(self)(buffer, result)
else:
buffer.append(result)
return result
class Player(object):
def __init__(self, buffer):
self._buffer = buffer
def __call__(self, *args, **kw):
buffer = object.__getattribute__(self, '_buffer')
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
else:
return result
@property
def _sqla_unwrap(self):
return None
def __getattribute__(self, key):
try:
return object.__getattribute__(self, key)
except AttributeError:
pass
buffer = object.__getattribute__(self, '_buffer')
result = buffer.popleft()
if result is ReplayableSession.Callable:
return self
elif result is ReplayableSession.NoAttribute:
raise AttributeError(key)
else:
return result

View File

@@ -0,0 +1,89 @@
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
_repr_stack = set()
class BasicEntity(object):
def __init__(self, **kw):
for key, value in kw.iteritems():
setattr(self, key, value)
def __repr__(self):
if id(self) in _repr_stack:
return object.__repr__(self)
_repr_stack.add(id(self))
try:
return "%s(%s)" % (
(self.__class__.__name__),
', '.join(["%s=%r" % (key, getattr(self, key))
for key in sorted(self.__dict__.keys())
if not key.startswith('_')]))
finally:
_repr_stack.remove(id(self))
_recursion_stack = set()
class ComparableEntity(BasicEntity):
def __hash__(self):
return hash(self.__class__)
def __ne__(self, other):
return not self.__eq__(other)
def __eq__(self, other):
"""'Deep, sparse compare.
Deeply compare two entities, following the non-None attributes of the
non-persisted object, if possible.
"""
if other is self:
return True
elif not self.__class__ == other.__class__:
return False
if id(self) in _recursion_stack:
return True
_recursion_stack.add(id(self))
try:
# pick the entity thats not SA persisted as the source
try:
self_key = sa.orm.attributes.instance_state(self).key
except sa.orm.exc.NO_STATE:
self_key = None
if other is None:
a = self
b = other
elif self_key is not None:
a = other
b = self
else:
a = self
b = other
for attr in a.__dict__.keys():
if attr.startswith('_'):
continue
value = getattr(a, attr)
try:
# handle lazy loader errors
battr = getattr(b, attr)
except (AttributeError, sa_exc.UnboundExecutionError):
return False
if hasattr(value, '__iter__'):
if list(value) != list(battr):
return False
else:
if value is not None and value != battr:
return False
return True
finally:
_recursion_stack.remove(id(self))

View File

@@ -0,0 +1,334 @@
from __future__ import with_statement
import operator
from nose import SkipTest
from ..util import decorator
from . import config
from .. import util
import contextlib
class skip_if(object):
def __init__(self, predicate, reason=None):
self.predicate = _as_predicate(predicate)
self.reason = reason
_fails_on = None
@property
def enabled(self):
return not self.predicate()
@contextlib.contextmanager
def fail_if(self, name='block'):
try:
yield
except Exception, ex:
if self.predicate():
print ("%s failed as expected (%s): %s " % (
name, self.predicate, str(ex)))
else:
raise
else:
if self.predicate():
raise AssertionError(
"Unexpected success for '%s' (%s)" %
(name, self.predicate))
def __call__(self, fn):
@decorator
def decorate(fn, *args, **kw):
if self.predicate():
if self.reason:
msg = "'%s' : %s" % (
fn.__name__,
self.reason
)
else:
msg = "'%s': %s" % (
fn.__name__, self.predicate
)
raise SkipTest(msg)
else:
if self._fails_on:
with self._fails_on.fail_if(name=fn.__name__):
return fn(*args, **kw)
else:
return fn(*args, **kw)
return decorate(fn)
def fails_on(self, other, reason=None):
self._fails_on = skip_if(other, reason)
return self
class fails_if(skip_if):
def __call__(self, fn):
@decorator
def decorate(fn, *args, **kw):
with self.fail_if(name=fn.__name__):
return fn(*args, **kw)
return decorate(fn)
def only_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return skip_if(NotPredicate(predicate), reason)
def succeeds_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return fails_if(NotPredicate(predicate), reason)
class Predicate(object):
@classmethod
def as_predicate(cls, predicate):
if isinstance(predicate, skip_if):
return predicate.predicate
elif isinstance(predicate, Predicate):
return predicate
elif isinstance(predicate, list):
return OrPredicate([cls.as_predicate(pred) for pred in predicate])
elif isinstance(predicate, tuple):
return SpecPredicate(*predicate)
elif isinstance(predicate, basestring):
return SpecPredicate(predicate, None, None)
elif util.callable(predicate):
return LambdaPredicate(predicate)
else:
assert False, "unknown predicate type: %s" % predicate
class BooleanPredicate(Predicate):
def __init__(self, value, description=None):
self.value = value
self.description = description or "boolean %s" % value
def __call__(self):
return self.value
def _as_string(self, negate=False):
if negate:
return "not " + self.description
else:
return self.description
def __str__(self):
return self._as_string()
class SpecPredicate(Predicate):
def __init__(self, db, op=None, spec=None, description=None):
self.db = db
self.op = op
self.spec = spec
self.description = description
_ops = {
'<': operator.lt,
'>': operator.gt,
'==': operator.eq,
'!=': operator.ne,
'<=': operator.le,
'>=': operator.ge,
'in': operator.contains,
'between': lambda val, pair: val >= pair[0] and val <= pair[1],
}
def __call__(self, engine=None):
if engine is None:
engine = config.db
if "+" in self.db:
dialect, driver = self.db.split('+')
else:
dialect, driver = self.db, None
if dialect and engine.name != dialect:
return False
if driver is not None and engine.driver != driver:
return False
if self.op is not None:
assert driver is None, "DBAPI version specs not supported yet"
version = _server_version(engine)
oper = hasattr(self.op, '__call__') and self.op \
or self._ops[self.op]
return oper(version, self.spec)
else:
return True
def _as_string(self, negate=False):
if self.description is not None:
return self.description
elif self.op is None:
if negate:
return "not %s" % self.db
else:
return "%s" % self.db
else:
if negate:
return "not %s %s %s" % (
self.db,
self.op,
self.spec
)
else:
return "%s %s %s" % (
self.db,
self.op,
self.spec
)
def __str__(self):
return self._as_string()
class LambdaPredicate(Predicate):
def __init__(self, lambda_, description=None, args=None, kw=None):
self.lambda_ = lambda_
self.args = args or ()
self.kw = kw or {}
if description:
self.description = description
elif lambda_.__doc__:
self.description = lambda_.__doc__
else:
self.description = "custom function"
def __call__(self):
return self.lambda_(*self.args, **self.kw)
def _as_string(self, negate=False):
if negate:
return "not " + self.description
else:
return self.description
def __str__(self):
return self._as_string()
class NotPredicate(Predicate):
def __init__(self, predicate):
self.predicate = predicate
def __call__(self, *arg, **kw):
return not self.predicate(*arg, **kw)
def __str__(self):
return self.predicate._as_string(True)
class OrPredicate(Predicate):
def __init__(self, predicates, description=None):
self.predicates = predicates
self.description = description
def __call__(self, *arg, **kw):
for pred in self.predicates:
if pred(*arg, **kw):
self._str = pred
return True
return False
_str = None
def _eval_str(self, negate=False):
if self._str is None:
if negate:
conjunction = " and "
else:
conjunction = " or "
return conjunction.join(p._as_string(negate=negate)
for p in self.predicates)
else:
return self._str._as_string(negate=negate)
def _negation_str(self):
if self.description is not None:
return "Not " + (self.description % {"spec": self._str})
else:
return self._eval_str(negate=True)
def _as_string(self, negate=False):
if negate:
return self._negation_str()
else:
if self.description is not None:
return self.description % {"spec": self._str}
else:
return self._eval_str()
def __str__(self):
return self._as_string()
_as_predicate = Predicate.as_predicate
def _is_excluded(db, op, spec):
return SpecPredicate(db, op, spec)()
def _server_version(engine):
"""Return a server_version_info tuple."""
# force metadata to be retrieved
conn = engine.connect()
version = getattr(engine.dialect, 'server_version_info', ())
conn.close()
return version
def db_spec(*dbs):
return OrPredicate(
Predicate.as_predicate(db) for db in dbs
)
def open():
return skip_if(BooleanPredicate(False, "mark as execute"))
def closed():
return skip_if(BooleanPredicate(True, "marked as skip"))
@decorator
def future(fn, *args, **kw):
return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature")
def fails_on(db, reason=None):
return fails_if(SpecPredicate(db), reason)
def fails_on_everything_except(*dbs):
return succeeds_if(
OrPredicate([
SpecPredicate(db) for db in dbs
])
)
def skip(db, reason=None):
return skip_if(SpecPredicate(db), reason)
def only_on(dbs, reason=None):
return only_if(
OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)])
)
def exclude(db, op, spec, reason=None):
return skip_if(SpecPredicate(db, op, spec), reason)
def against(*queries):
return OrPredicate([
Predicate.as_predicate(query)
for query in queries
])()

View File

@@ -0,0 +1,344 @@
from . import config
from . import assertions, schema
from .util import adict
from .engines import drop_all_tables
from .entities import BasicEntity, ComparableEntity
import sys
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
class TestBase(object):
# A sequence of database names to always run, regardless of the
# constraints below.
__whitelist__ = ()
# A sequence of requirement names matching testing.requires decorators
__requires__ = ()
# A sequence of dialect names to exclude from the test class.
__unsupported_on__ = ()
# If present, test class is only runnable for the *single* specified
# dialect. If you need multiple, use __unsupported_on__ and invert.
__only_on__ = None
# A sequence of no-arg callables. If any are True, the entire testcase is
# skipped.
__skip_if__ = None
def assert_(self, val, msg=None):
assert val, msg
class TablesTest(TestBase):
# 'once', None
run_setup_bind = 'once'
# 'once', 'each', None
run_define_tables = 'once'
# 'once', 'each', None
run_create_tables = 'once'
# 'once', 'each', None
run_inserts = 'each'
# 'each', None
run_deletes = 'each'
# 'once', None
run_dispose_bind = None
bind = None
metadata = None
tables = None
other = None
@classmethod
def setup_class(cls):
cls._init_class()
cls._setup_once_tables()
cls._setup_once_inserts()
@classmethod
def _init_class(cls):
if cls.run_define_tables == 'each':
if cls.run_create_tables == 'once':
cls.run_create_tables = 'each'
assert cls.run_inserts in ('each', None)
if cls.other is None:
cls.other = adict()
if cls.tables is None:
cls.tables = adict()
if cls.bind is None:
setattr(cls, 'bind', cls.setup_bind())
if cls.metadata is None:
setattr(cls, 'metadata', sa.MetaData())
if cls.metadata.bind is None:
cls.metadata.bind = cls.bind
@classmethod
def _setup_once_inserts(cls):
if cls.run_inserts == 'once':
cls._load_fixtures()
cls.insert_data()
@classmethod
def _setup_once_tables(cls):
if cls.run_define_tables == 'once':
cls.define_tables(cls.metadata)
if cls.run_create_tables == 'once':
cls.metadata.create_all(cls.bind)
cls.tables.update(cls.metadata.tables)
def _setup_each_tables(self):
if self.run_define_tables == 'each':
self.tables.clear()
if self.run_create_tables == 'each':
drop_all_tables(self.metadata, self.bind)
self.metadata.clear()
self.define_tables(self.metadata)
if self.run_create_tables == 'each':
self.metadata.create_all(self.bind)
self.tables.update(self.metadata.tables)
elif self.run_create_tables == 'each':
drop_all_tables(self.metadata, self.bind)
self.metadata.create_all(self.bind)
def _setup_each_inserts(self):
if self.run_inserts == 'each':
self._load_fixtures()
self.insert_data()
def _teardown_each_tables(self):
# no need to run deletes if tables are recreated on setup
if self.run_define_tables != 'each' and self.run_deletes == 'each':
for table in reversed(self.metadata.sorted_tables):
try:
table.delete().execute().close()
except sa.exc.DBAPIError, ex:
print >> sys.stderr, "Error emptying table %s: %r" % (
table, ex)
def setup(self):
self._setup_each_tables()
self._setup_each_inserts()
def teardown(self):
self._teardown_each_tables()
@classmethod
def _teardown_once_metadata_bind(cls):
if cls.run_create_tables:
drop_all_tables(cls.metadata, cls.bind)
if cls.run_dispose_bind == 'once':
cls.dispose_bind(cls.bind)
cls.metadata.bind = None
if cls.run_setup_bind is not None:
cls.bind = None
@classmethod
def teardown_class(cls):
cls._teardown_once_metadata_bind()
@classmethod
def setup_bind(cls):
return config.db
@classmethod
def dispose_bind(cls, bind):
if hasattr(bind, 'dispose'):
bind.dispose()
elif hasattr(bind, 'close'):
bind.close()
@classmethod
def define_tables(cls, metadata):
pass
@classmethod
def fixtures(cls):
return {}
@classmethod
def insert_data(cls):
pass
def sql_count_(self, count, fn):
self.assert_sql_count(self.bind, fn, count)
def sql_eq_(self, callable_, statements, with_sequences=None):
self.assert_sql(self.bind,
callable_, statements, with_sequences)
@classmethod
def _load_fixtures(cls):
"""Insert rows as represented by the fixtures() method."""
headers, rows = {}, {}
for table, data in cls.fixtures().iteritems():
if len(data) < 2:
continue
if isinstance(table, basestring):
table = cls.tables[table]
headers[table] = data[0]
rows[table] = data[1:]
for table in cls.metadata.sorted_tables:
if table not in headers:
continue
cls.bind.execute(
table.insert(),
[dict(zip(headers[table], column_values))
for column_values in rows[table]])
class _ORMTest(object):
@classmethod
def teardown_class(cls):
sa.orm.session.Session.close_all()
sa.orm.clear_mappers()
class ORMTest(_ORMTest, TestBase):
pass
class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
run_setup_classes = 'once'
# 'once', 'each', None
run_setup_mappers = 'each'
classes = None
@classmethod
def setup_class(cls):
cls._init_class()
if cls.classes is None:
cls.classes = adict()
cls._setup_once_tables()
cls._setup_once_classes()
cls._setup_once_mappers()
cls._setup_once_inserts()
@classmethod
def teardown_class(cls):
cls._teardown_once_class()
cls._teardown_once_metadata_bind()
def setup(self):
self._setup_each_tables()
self._setup_each_mappers()
self._setup_each_inserts()
def teardown(self):
sa.orm.session.Session.close_all()
self._teardown_each_mappers()
self._teardown_each_tables()
@classmethod
def _teardown_once_class(cls):
cls.classes.clear()
_ORMTest.teardown_class()
@classmethod
def _setup_once_classes(cls):
if cls.run_setup_classes == 'once':
cls._with_register_classes(cls.setup_classes)
@classmethod
def _setup_once_mappers(cls):
if cls.run_setup_mappers == 'once':
cls._with_register_classes(cls.setup_mappers)
def _setup_each_mappers(self):
if self.run_setup_mappers == 'each':
self._with_register_classes(self.setup_mappers)
@classmethod
def _with_register_classes(cls, fn):
"""Run a setup method, framing the operation with a Base class
that will catch new subclasses to be established within
the "classes" registry.
"""
cls_registry = cls.classes
class FindFixture(type):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
return type.__init__(cls, classname, bases, dict_)
class _Base(object):
__metaclass__ = FindFixture
class Basic(BasicEntity, _Base):
pass
class Comparable(ComparableEntity, _Base):
pass
cls.Basic = Basic
cls.Comparable = Comparable
fn()
def _teardown_each_mappers(self):
# some tests create mappers in the test bodies
# and will define setup_mappers as None -
# clear mappers in any case
if self.run_setup_mappers != 'once':
sa.orm.clear_mappers()
@classmethod
def setup_classes(cls):
pass
@classmethod
def setup_mappers(cls):
pass
class DeclarativeMappedTest(MappedTest):
run_setup_classes = 'once'
run_setup_mappers = 'once'
@classmethod
def _setup_once_tables(cls):
pass
@classmethod
def _with_register_classes(cls, fn):
cls_registry = cls.classes
class FindFixtureDeclarative(DeclarativeMeta):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
return DeclarativeMeta.__init__(
cls, classname, bases, dict_)
class DeclarativeBasic(object):
__table_cls__ = schema.Table
_DeclBase = declarative_base(metadata=cls.metadata,
metaclass=FindFixtureDeclarative,
cls=DeclarativeBasic)
cls.DeclarativeBasic = _DeclBase
fn()
if cls.metadata.tables:
cls.metadata.create_all(config.db)

View File

@@ -0,0 +1,15 @@
"""Import stub for mock library.
"""
from __future__ import absolute_import
from ..util import py33
if py33:
from unittest.mock import MagicMock, Mock, call
else:
try:
from mock import MagicMock, Mock, call
except ImportError:
raise ImportError(
"SQLAlchemy's test suite requires the "
"'mock' library as of 0.8.2.")

View File

@@ -0,0 +1,136 @@
"""Classes used in pickling tests, need to be at the module level for
unpickling.
"""
from . import fixtures
class User(fixtures.ComparableEntity):
pass
class Order(fixtures.ComparableEntity):
pass
class Dingaling(fixtures.ComparableEntity):
pass
class EmailUser(User):
pass
class Address(fixtures.ComparableEntity):
pass
# TODO: these are kind of arbitrary....
class Child1(fixtures.ComparableEntity):
pass
class Child2(fixtures.ComparableEntity):
pass
class Parent(fixtures.ComparableEntity):
pass
class Screen(object):
def __init__(self, obj, parent=None):
self.obj = obj
self.parent = parent
class Foo(object):
def __init__(self, moredata):
self.data = 'im data'
self.stuff = 'im stuff'
self.moredata = moredata
__hash__ = object.__hash__
def __eq__(self, other):
return other.data == self.data and \
other.stuff == self.stuff and \
other.moredata == self.moredata
class Bar(object):
def __init__(self, x, y):
self.x = x
self.y = y
__hash__ = object.__hash__
def __eq__(self, other):
return other.__class__ is self.__class__ and \
other.x == self.x and \
other.y == self.y
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
class OldSchool:
def __init__(self, x, y):
self.x = x
self.y = y
def __eq__(self, other):
return other.__class__ is self.__class__ and \
other.x == self.x and \
other.y == self.y
class OldSchoolWithoutCompare:
def __init__(self, x, y):
self.x = x
self.y = y
class BarWithoutCompare(object):
def __init__(self, x, y):
self.x = x
self.y = y
def __str__(self):
return "Bar(%d, %d)" % (self.x, self.y)
class NotComparable(object):
def __init__(self, data):
self.data = data
def __hash__(self):
return id(self)
def __eq__(self, other):
return NotImplemented
def __ne__(self, other):
return NotImplemented
class BrokenComparable(object):
def __init__(self, data):
self.data = data
def __hash__(self):
return id(self)
def __eq__(self, other):
raise NotImplementedError
def __ne__(self, other):
raise NotImplementedError

View File

@@ -0,0 +1,451 @@
"""Enhance nose with extra options and behaviors for running SQLAlchemy tests.
When running ./sqla_nose.py, this module is imported relative to the
"plugins" package as a top level package by the sqla_nose.py runner,
so that the plugin can be loaded with the rest of nose including the coverage
plugin before any of SQLAlchemy itself is imported, so that coverage works.
When third party libraries use this plugin, it can be imported
normally as "from sqlalchemy.testing.plugin import noseplugin".
"""
from __future__ import absolute_import
import os
import ConfigParser
from nose.plugins import Plugin
from nose import SkipTest
import time
import sys
import re
# late imports
fixtures = None
engines = None
exclusions = None
warnings = None
profiling = None
assertions = None
requirements = None
config = None
util = None
file_config = None
logging = None
db = None
db_label = None
db_url = None
db_opts = {}
options = None
_existing_engine = None
def _log(option, opt_str, value, parser):
global logging
if not logging:
import logging
logging.basicConfig()
if opt_str.endswith('-info'):
logging.getLogger(value).setLevel(logging.INFO)
elif opt_str.endswith('-debug'):
logging.getLogger(value).setLevel(logging.DEBUG)
def _list_dbs(*args):
print "Available --db options (use --dburi to override)"
for macro in sorted(file_config.options('db')):
print "%20s\t%s" % (macro, file_config.get('db', macro))
sys.exit(0)
def _server_side_cursors(options, opt_str, value, parser):
db_opts['server_side_cursors'] = True
def _engine_strategy(options, opt_str, value, parser):
if value:
db_opts['strategy'] = value
pre_configure = []
post_configure = []
def pre(fn):
pre_configure.append(fn)
return fn
def post(fn):
post_configure.append(fn)
return fn
@pre
def _setup_options(opt, file_config):
global options
options = opt
@pre
def _monkeypatch_cdecimal(options, file_config):
if options.cdecimal:
import cdecimal
sys.modules['decimal'] = cdecimal
@post
def _engine_uri(options, file_config):
global db_label, db_url
if options.dburi:
db_url = options.dburi
db_label = db_url[:db_url.index(':')]
elif options.db:
db_label = options.db
db_url = None
if db_url is None:
if db_label not in file_config.options('db'):
raise RuntimeError(
"Unknown URI specifier '%s'. Specify --dbs for known uris."
% db_label)
db_url = file_config.get('db', db_label)
@post
def _require(options, file_config):
if not(options.require or
(file_config.has_section('require') and
file_config.items('require'))):
return
try:
import pkg_resources
except ImportError:
raise RuntimeError("setuptools is required for version requirements")
cmdline = []
for requirement in options.require:
pkg_resources.require(requirement)
cmdline.append(re.split('\s*(<!>=)', requirement, 1)[0])
if file_config.has_section('require'):
for label, requirement in file_config.items('require'):
if not label == db_label or label.startswith('%s.' % db_label):
continue
seen = [c for c in cmdline if requirement.startswith(c)]
if seen:
continue
pkg_resources.require(requirement)
@post
def _engine_pool(options, file_config):
if options.mockpool:
from sqlalchemy import pool
db_opts['poolclass'] = pool.AssertionPool
@post
def _create_testing_engine(options, file_config):
from sqlalchemy.testing import engines, config
from sqlalchemy import testing
global db
config.db = testing.db = db = engines.testing_engine(db_url, db_opts)
config.db.connect().close()
config.db_opts = db_opts
config.db_url = db_url
@post
def _prep_testing_database(options, file_config):
from sqlalchemy.testing import engines
from sqlalchemy import schema, inspect
# also create alt schemas etc. here?
if options.dropfirst:
e = engines.utf8_engine()
inspector = inspect(e)
try:
view_names = inspector.get_view_names()
except NotImplementedError:
pass
else:
for vname in view_names:
e.execute(schema._DropView(schema.Table(vname, schema.MetaData())))
try:
view_names = inspector.get_view_names(schema="test_schema")
except NotImplementedError:
pass
else:
for vname in view_names:
e.execute(schema._DropView(
schema.Table(vname,
schema.MetaData(), schema="test_schema")))
for tname in reversed(inspector.get_table_names(order_by="foreign_key")):
e.execute(schema.DropTable(schema.Table(tname, schema.MetaData())))
for tname in reversed(inspector.get_table_names(
order_by="foreign_key", schema="test_schema")):
e.execute(schema.DropTable(
schema.Table(tname, schema.MetaData(), schema="test_schema")))
e.dispose()
@post
def _set_table_options(options, file_config):
from sqlalchemy.testing import schema
table_options = schema.table_options
for spec in options.tableopts:
key, value = spec.split('=')
table_options[key] = value
if options.mysql_engine:
table_options['mysql_engine'] = options.mysql_engine
@post
def _reverse_topological(options, file_config):
if options.reversetop:
from sqlalchemy.orm.util import randomize_unitofwork
randomize_unitofwork()
def _requirements_opt(options, opt_str, value, parser):
_setup_requirements(value)
@post
def _requirements(options, file_config):
requirement_cls = file_config.get('sqla_testing', "requirement_cls")
_setup_requirements(requirement_cls)
def _setup_requirements(argument):
from sqlalchemy.testing import config
from sqlalchemy import testing
if config.requirements is not None:
return
modname, clsname = argument.split(":")
# importlib.import_module() only introduced in 2.7, a little
# late
mod = __import__(modname)
for component in modname.split(".")[1:]:
mod = getattr(mod, component)
req_cls = getattr(mod, clsname)
config.requirements = testing.requires = req_cls(config)
@post
def _post_setup_options(opt, file_config):
from sqlalchemy.testing import config
config.options = options
config.file_config = file_config
@post
def _setup_profiling(options, file_config):
from sqlalchemy.testing import profiling
profiling._profile_stats = profiling.ProfileStatsFile(
file_config.get('sqla_testing', 'profile_file'))
class NoseSQLAlchemy(Plugin):
"""
Handles the setup and extra properties required for testing SQLAlchemy
"""
enabled = True
name = 'sqla_testing'
score = 100
def options(self, parser, env=os.environ):
Plugin.options(self, parser, env)
opt = parser.add_option
opt("--log-info", action="callback", type="string", callback=_log,
help="turn on info logging for <LOG> (multiple OK)")
opt("--log-debug", action="callback", type="string", callback=_log,
help="turn on debug logging for <LOG> (multiple OK)")
opt("--require", action="append", dest="require", default=[],
help="require a particular driver or module version (multiple OK)")
opt("--db", action="store", dest="db", default="default",
help="Use prefab database uri")
opt('--dbs', action='callback', callback=_list_dbs,
help="List available prefab dbs")
opt("--dburi", action="store", dest="dburi",
help="Database uri (overrides --db)")
opt("--dropfirst", action="store_true", dest="dropfirst",
help="Drop all tables in the target database first")
opt("--mockpool", action="store_true", dest="mockpool",
help="Use mock pool (asserts only one connection used)")
opt("--low-connections", action="store_true", dest="low_connections",
help="Use a low number of distinct connections - i.e. for Oracle TNS"
)
opt("--enginestrategy", action="callback", type="string",
callback=_engine_strategy,
help="Engine strategy (plain or threadlocal, defaults to plain)")
opt("--reversetop", action="store_true", dest="reversetop", default=False,
help="Use a random-ordering set implementation in the ORM (helps "
"reveal dependency issues)")
opt("--requirements", action="callback", type="string",
callback=_requirements_opt,
help="requirements class for testing, overrides setup.cfg")
opt("--with-cdecimal", action="store_true", dest="cdecimal", default=False,
help="Monkeypatch the cdecimal library into Python 'decimal' for all tests")
opt("--unhashable", action="store_true", dest="unhashable", default=False,
help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
help="Disallow SQLAlchemy from performing == on mapped test objects.")
opt("--truthless", action="store_true", dest="truthless", default=False,
help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
opt("--serverside", action="callback", callback=_server_side_cursors,
help="Turn on server side cursors for PG")
opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
help="Use the specified MySQL storage engine for all tables, default is "
"a db-default/InnoDB combo.")
opt("--table-option", action="append", dest="tableopts", default=[],
help="Add a dialect-specific table option, key=value")
opt("--write-profiles", action="store_true", dest="write_profiles", default=False,
help="Write/update profiling data.")
global file_config
file_config = ConfigParser.ConfigParser()
file_config.read(['setup.cfg', 'test.cfg'])
def configure(self, options, conf):
Plugin.configure(self, options, conf)
self.options = options
for fn in pre_configure:
fn(self.options, file_config)
def begin(self):
# Lazy setup of other options (post coverage)
for fn in post_configure:
fn(self.options, file_config)
# late imports, has to happen after config as well
# as nose plugins like coverage
global util, fixtures, engines, exclusions, \
assertions, warnings, profiling,\
config
from sqlalchemy.testing import fixtures, engines, exclusions, \
assertions, warnings, profiling, config
from sqlalchemy import util
def describeTest(self, test):
return ""
def wantFunction(self, fn):
if fn.__module__.startswith('sqlalchemy.testing'):
return False
def wantClass(self, cls):
"""Return true if you want the main test selector to collect
tests from this class, false if you don't, and None if you don't
care.
:Parameters:
cls : class
The class being examined by the selector
"""
if not issubclass(cls, fixtures.TestBase):
return False
elif cls.__name__.startswith('_'):
return False
else:
return True
def _do_skips(self, cls):
from sqlalchemy.testing import config
if hasattr(cls, '__requires__'):
def test_suite():
return 'ok'
test_suite.__name__ = cls.__name__
for requirement in cls.__requires__:
check = getattr(config.requirements, requirement)
if not check.enabled:
raise SkipTest(
check.reason if check.reason
else
(
"'%s' unsupported on DB implementation '%s'" % (
cls.__name__, config.db.name
)
)
)
if cls.__unsupported_on__:
spec = exclusions.db_spec(*cls.__unsupported_on__)
if spec(config.db):
raise SkipTest(
"'%s' unsupported on DB implementation '%s'" % (
cls.__name__, config.db.name)
)
if getattr(cls, '__only_on__', None):
spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
if not spec(config.db):
raise SkipTest(
"'%s' unsupported on DB implementation '%s'" % (
cls.__name__, config.db.name)
)
if getattr(cls, '__skip_if__', False):
for c in getattr(cls, '__skip_if__'):
if c():
raise SkipTest("'%s' skipped by %s" % (
cls.__name__, c.__name__)
)
for db, op, spec in getattr(cls, '__excluded_on__', ()):
exclusions.exclude(db, op, spec,
"'%s' unsupported on DB %s version %s" % (
cls.__name__, config.db.name,
exclusions._server_version(config.db)))
def beforeTest(self, test):
warnings.resetwarnings()
profiling._current_test = test.id()
def afterTest(self, test):
engines.testing_reaper._after_test_ctx()
warnings.resetwarnings()
def _setup_engine(self, ctx):
if getattr(ctx, '__engine_options__', None):
global _existing_engine
_existing_engine = config.db
config.db = engines.testing_engine(options=ctx.__engine_options__)
def _restore_engine(self, ctx):
global _existing_engine
if _existing_engine is not None:
config.db = _existing_engine
_existing_engine = None
def startContext(self, ctx):
if not isinstance(ctx, type) \
or not issubclass(ctx, fixtures.TestBase):
return
self._do_skips(ctx)
self._setup_engine(ctx)
def stopContext(self, ctx):
if not isinstance(ctx, type) \
or not issubclass(ctx, fixtures.TestBase):
return
engines.testing_reaper._stop_test_ctx()
if not options.low_connections:
assertions.global_cleanup_assertions()
self._restore_engine(ctx)

View File

@@ -0,0 +1,294 @@
"""Profiling support for unit and performance tests.
These are special purpose profiling methods which operate
in a more fine-grained way than nose's profiling plugin.
"""
import os
import sys
from .util import gc_collect, decorator
from . import config
from nose import SkipTest
import pstats
import time
import collections
from .. import util
try:
import cProfile
except ImportError:
cProfile = None
from ..util import jython, pypy, win32, update_wrapper
_current_test = None
def profiled(target=None, **target_opts):
"""Function profiling.
@profiled()
or
@profiled(report=True, sort=('calls',), limit=20)
Outputs profiling info for a decorated function.
"""
profile_config = {'targets': set(),
'report': True,
'print_callers': False,
'print_callees': False,
'graphic': False,
'sort': ('time', 'calls'),
'limit': None}
if target is None:
target = 'anonymous_target'
filename = "%s.prof" % target
@decorator
def decorate(fn, *args, **kw):
elapsed, load_stats, result = _profile(
filename, fn, *args, **kw)
graphic = target_opts.get('graphic', profile_config['graphic'])
if graphic:
os.system("runsnake %s" % filename)
else:
report = target_opts.get('report', profile_config['report'])
if report:
sort_ = target_opts.get('sort', profile_config['sort'])
limit = target_opts.get('limit', profile_config['limit'])
print ("Profile report for target '%s' (%s)" % (
target, filename)
)
stats = load_stats()
stats.sort_stats(*sort_)
if limit:
stats.print_stats(limit)
else:
stats.print_stats()
print_callers = target_opts.get(
'print_callers', profile_config['print_callers'])
if print_callers:
stats.print_callers()
print_callees = target_opts.get(
'print_callees', profile_config['print_callees'])
if print_callees:
stats.print_callees()
os.unlink(filename)
return result
return decorate
class ProfileStatsFile(object):
""""Store per-platform/fn profiling results in a file.
We're still targeting Py2.5, 2.4 on 0.7 with no dependencies,
so no json lib :( need to roll something silly
"""
def __init__(self, filename):
self.write = (
config.options is not None and
config.options.write_profiles
)
self.fname = os.path.abspath(filename)
self.short_fname = os.path.split(self.fname)[-1]
self.data = collections.defaultdict(
lambda: collections.defaultdict(dict))
self._read()
if self.write:
# rewrite for the case where features changed,
# etc.
self._write()
@util.memoized_property
def platform_key(self):
dbapi_key = config.db.name + "_" + config.db.driver
# keep it at 2.7, 3.1, 3.2, etc. for now.
py_version = '.'.join([str(v) for v in sys.version_info[0:2]])
platform_tokens = [py_version]
platform_tokens.append(dbapi_key)
if jython:
platform_tokens.append("jython")
if pypy:
platform_tokens.append("pypy")
if win32:
platform_tokens.append("win")
_has_cext = config.requirements._has_cextensions()
platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
return "_".join(platform_tokens)
def has_stats(self):
test_key = _current_test
return (
test_key in self.data and
self.platform_key in self.data[test_key]
)
def result(self, callcount):
test_key = _current_test
per_fn = self.data[test_key]
per_platform = per_fn[self.platform_key]
if 'counts' not in per_platform:
per_platform['counts'] = counts = []
else:
counts = per_platform['counts']
if 'current_count' not in per_platform:
per_platform['current_count'] = current_count = 0
else:
current_count = per_platform['current_count']
has_count = len(counts) > current_count
if not has_count:
counts.append(callcount)
if self.write:
self._write()
result = None
else:
result = per_platform['lineno'], counts[current_count]
per_platform['current_count'] += 1
return result
def _header(self):
return \
"# %s\n"\
"# This file is written out on a per-environment basis.\n"\
"# For each test in aaa_profiling, the corresponding function and \n"\
"# environment is located within this file. If it doesn't exist,\n"\
"# the test is skipped.\n"\
"# If a callcount does exist, it is compared to what we received. \n"\
"# assertions are raised if the counts do not match.\n"\
"# \n"\
"# To add a new callcount test, apply the function_call_count \n"\
"# decorator and re-run the tests using the --write-profiles \n"\
"# option - this file will be rewritten including the new count.\n"\
"# \n"\
"" % (self.fname)
def _read(self):
try:
profile_f = open(self.fname)
except IOError:
return
for lineno, line in enumerate(profile_f):
line = line.strip()
if not line or line.startswith("#"):
continue
test_key, platform_key, counts = line.split()
per_fn = self.data[test_key]
per_platform = per_fn[platform_key]
c = [int(count) for count in counts.split(",")]
per_platform['counts'] = c
per_platform['lineno'] = lineno + 1
per_platform['current_count'] = 0
profile_f.close()
def _write(self):
print("Writing profile file %s" % self.fname)
profile_f = open(self.fname, "w")
profile_f.write(self._header())
for test_key in sorted(self.data):
per_fn = self.data[test_key]
profile_f.write("\n# TEST: %s\n\n" % test_key)
for platform_key in sorted(per_fn):
per_platform = per_fn[platform_key]
c = ",".join(str(count) for count in per_platform['counts'])
profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
profile_f.close()
def function_call_count(variance=0.05):
"""Assert a target for a test case's function call count.
The main purpose of this assertion is to detect changes in
callcounts for various functions - the actual number is not as important.
Callcounts are stored in a file keyed to Python version and OS platform
information. This file is generated automatically for new tests,
and versioned so that unexpected changes in callcounts will be detected.
"""
def decorate(fn):
def wrap(*args, **kw):
if cProfile is None:
raise SkipTest("cProfile is not installed")
if not _profile_stats.has_stats() and not _profile_stats.write:
# run the function anyway, to support dependent tests
# (not a great idea but we have these in test_zoomark)
fn(*args, **kw)
raise SkipTest("No profiling stats available on this "
"platform for this function. Run tests with "
"--write-profiles to add statistics to %s for "
"this platform." % _profile_stats.short_fname)
gc_collect()
timespent, load_stats, fn_result = _profile(
fn, *args, **kw
)
stats = load_stats()
callcount = stats.total_calls
expected = _profile_stats.result(callcount)
if expected is None:
expected_count = None
else:
line_no, expected_count = expected
print("Pstats calls: %d Expected %s" % (
callcount,
expected_count
)
)
stats.print_stats()
#stats.print_callers()
if expected_count:
deviance = int(callcount * variance)
if abs(callcount - expected_count) > deviance:
raise AssertionError(
"Adjusted function call count %s not within %s%% "
"of expected %s. (Delete line %d of file %s to "
"regenerate this callcount, when tests are run "
"with --write-profiles.)"
% (
callcount, (variance * 100),
expected_count, line_no,
_profile_stats.fname))
return fn_result
return update_wrapper(wrap, fn)
return decorate
def _profile(fn, *args, **kw):
filename = "%s.prof" % fn.__name__
def load_stats():
st = pstats.Stats(filename)
os.unlink(filename)
return st
began = time.time()
cProfile.runctx('result = fn(*args, **kw)', globals(), locals(),
filename=filename)
ended = time.time()
return ended - began, load_stats, locals()['result']

View File

@@ -0,0 +1,438 @@
"""Global database feature support policy.
Provides decorators to mark tests requiring specific feature support from the
target database.
External dialect test suites should subclass SuiteRequirements
to provide specific inclusion/exlusions.
"""
from . import exclusions, config
class Requirements(object):
def __init__(self, config):
self.config = config
@property
def db(self):
return config.db
class SuiteRequirements(Requirements):
@property
def create_table(self):
"""target platform can emit basic CreateTable DDL."""
return exclusions.open()
@property
def drop_table(self):
"""target platform can emit basic DropTable DDL."""
return exclusions.open()
@property
def foreign_keys(self):
"""Target database must support foreign keys."""
return exclusions.open()
@property
def on_update_cascade(self):
""""target database must support ON UPDATE..CASCADE behavior in
foreign keys."""
return exclusions.open()
@property
def deferrable_fks(self):
return exclusions.closed()
@property
def on_update_or_deferrable_fks(self):
# TODO: exclusions should be composable,
# somehow only_if([x, y]) isn't working here, negation/conjunctions
# getting confused.
return exclusions.only_if(
lambda: self.on_update_cascade.enabled or self.deferrable_fks.enabled
)
@property
def self_referential_foreign_keys(self):
"""Target database must support self-referential foreign keys."""
return exclusions.open()
@property
def foreign_key_ddl(self):
"""Target database must support the DDL phrases for FOREIGN KEY."""
return exclusions.open()
@property
def named_constraints(self):
"""target database must support names for constraints."""
return exclusions.open()
@property
def subqueries(self):
"""Target database must support subqueries."""
return exclusions.open()
@property
def offset(self):
"""target database can render OFFSET, or an equivalent, in a SELECT."""
return exclusions.open()
@property
def boolean_col_expressions(self):
"""Target database must support boolean expressions as columns"""
return exclusions.closed()
@property
def nullsordering(self):
"""Target backends that support nulls ordering."""
return exclusions.closed()
@property
def standalone_binds(self):
"""target database/driver supports bound parameters as column expressions
without being in the context of a typed column.
"""
return exclusions.closed()
@property
def intersect(self):
"""Target database must support INTERSECT or equivalent."""
return exclusions.closed()
@property
def except_(self):
"""Target database must support EXCEPT or equivalent (i.e. MINUS)."""
return exclusions.closed()
@property
def window_functions(self):
"""Target database must support window functions."""
return exclusions.closed()
@property
def autoincrement_insert(self):
"""target platform generates new surrogate integer primary key values
when insert() is executed, excluding the pk column."""
return exclusions.open()
@property
def empty_inserts(self):
"""target platform supports INSERT with no values, i.e.
INSERT DEFAULT VALUES or equivalent."""
return exclusions.only_if(
lambda: self.config.db.dialect.supports_empty_insert or \
self.config.db.dialect.supports_default_values,
"empty inserts not supported"
)
@property
def insert_from_select(self):
"""target platform supports INSERT from a SELECT."""
return exclusions.open()
@property
def returning(self):
"""target platform supports RETURNING."""
return exclusions.only_if(
lambda: self.config.db.dialect.implicit_returning,
"'returning' not supported by database"
)
@property
def denormalized_names(self):
"""Target database must have 'denormalized', i.e.
UPPERCASE as case insensitive names."""
return exclusions.skip_if(
lambda: not self.db.dialect.requires_name_normalize,
"Backend does not require denormalized names."
)
@property
def multivalues_inserts(self):
"""target database must support multiple VALUES clauses in an
INSERT statement."""
return exclusions.skip_if(
lambda: not self.db.dialect.supports_multivalues_insert,
"Backend does not support multirow inserts."
)
@property
def implements_get_lastrowid(self):
""""target dialect implements the executioncontext.get_lastrowid()
method without reliance on RETURNING.
"""
return exclusions.open()
@property
def emulated_lastrowid(self):
""""target dialect retrieves cursor.lastrowid, or fetches
from a database-side function after an insert() construct executes,
within the get_lastrowid() method.
Only dialects that "pre-execute", or need RETURNING to get last
inserted id, would return closed/fail/skip for this.
"""
return exclusions.closed()
@property
def dbapi_lastrowid(self):
""""target platform includes a 'lastrowid' accessor on the DBAPI
cursor object.
"""
return exclusions.closed()
@property
def views(self):
"""Target database must support VIEWs."""
return exclusions.closed()
@property
def schemas(self):
"""Target database must support external schemas, and have one
named 'test_schema'."""
return exclusions.closed()
@property
def sequences(self):
"""Target database must support SEQUENCEs."""
return exclusions.only_if([
lambda: self.config.db.dialect.supports_sequences
], "no sequence support")
@property
def sequences_optional(self):
"""Target database supports sequences, but also optionally
as a means of generating new PK values."""
return exclusions.only_if([
lambda: self.config.db.dialect.supports_sequences and \
self.config.db.dialect.sequences_optional
], "no sequence support, or sequences not optional")
@property
def reflects_pk_names(self):
return exclusions.closed()
@property
def table_reflection(self):
return exclusions.open()
@property
def view_reflection(self):
return self.views
@property
def schema_reflection(self):
return self.schemas
@property
def primary_key_constraint_reflection(self):
return exclusions.open()
@property
def foreign_key_constraint_reflection(self):
return exclusions.open()
@property
def index_reflection(self):
return exclusions.open()
@property
def unbounded_varchar(self):
"""Target database must support VARCHAR with no length"""
return exclusions.open()
@property
def unicode_data(self):
"""Target database/dialect must support Python unicode objects with
non-ASCII characters represented, delivered as bound parameters
as well as in result rows.
"""
return exclusions.open()
@property
def unicode_ddl(self):
"""Target driver must support some degree of non-ascii symbol names."""
return exclusions.closed()
@property
def datetime(self):
"""target dialect supports representation of Python
datetime.datetime() objects."""
return exclusions.open()
@property
def datetime_microseconds(self):
"""target dialect supports representation of Python
datetime.datetime() with microsecond objects."""
return exclusions.open()
@property
def datetime_historic(self):
"""target dialect supports representation of Python
datetime.datetime() objects with historic (pre 1970) values."""
return exclusions.closed()
@property
def date(self):
"""target dialect supports representation of Python
datetime.date() objects."""
return exclusions.open()
@property
def date_historic(self):
"""target dialect supports representation of Python
datetime.datetime() objects with historic (pre 1970) values."""
return exclusions.closed()
@property
def time(self):
"""target dialect supports representation of Python
datetime.time() objects."""
return exclusions.open()
@property
def time_microseconds(self):
"""target dialect supports representation of Python
datetime.time() with microsecond objects."""
return exclusions.open()
@property
def precision_numerics_general(self):
"""target backend has general support for moderately high-precision
numerics."""
return exclusions.open()
@property
def precision_numerics_enotation_small(self):
"""target backend supports Decimal() objects using E notation
to represent very small values."""
return exclusions.closed()
@property
def precision_numerics_enotation_large(self):
"""target backend supports Decimal() objects using E notation
to represent very large values."""
return exclusions.closed()
@property
def precision_numerics_many_significant_digits(self):
"""target backend supports values with many digits on both sides,
such as 319438950232418390.273596, 87673.594069654243
"""
return exclusions.closed()
@property
def precision_numerics_retains_significant_digits(self):
"""A precision numeric type will return empty significant digits,
i.e. a value such as 10.000 will come back in Decimal form with
the .000 maintained."""
return exclusions.closed()
@property
def text_type(self):
"""Target database must support an unbounded Text() "
"type such as TEXT or CLOB"""
return exclusions.open()
@property
def empty_strings_varchar(self):
"""target database can persist/return an empty string with a
varchar.
"""
return exclusions.open()
@property
def empty_strings_text(self):
"""target database can persist/return an empty string with an
unbounded text."""
return exclusions.open()
@property
def update_from(self):
"""Target must support UPDATE..FROM syntax"""
return exclusions.closed()
@property
def update_where_target_in_subquery(self):
"""Target must support UPDATE where the same table is present in a
subquery in the WHERE clause.
This is an ANSI-standard syntax that apparently MySQL can't handle,
such as:
UPDATE documents SET flag=1 WHERE documents.title IN
(SELECT max(documents.title) AS title
FROM documents GROUP BY documents.user_id
)
"""
return exclusions.open()
@property
def mod_operator_as_percent_sign(self):
"""target database must use a plain percent '%' as the 'modulus'
operator."""
return exclusions.closed()
@property
def unicode_connections(self):
"""Target driver must support non-ASCII characters being passed at all."""
return exclusions.open()
@property
def skip_mysql_on_windows(self):
"""Catchall for a large variety of MySQL on Windows failures"""
return exclusions.open()
def _has_mysql_on_windows(self):
return False
def _has_mysql_fully_case_sensitive(self):
return False

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python
"""
Nose test runner module.
This script is a front-end to "nosetests" which
installs SQLAlchemy's testing plugin into the local environment.
The script is intended to be used by third-party dialects and extensions
that run within SQLAlchemy's testing framework. The runner can
be invoked via::
python -m sqlalchemy.testing.runner
The script is then essentially the same as the "nosetests" script, including
all of the usual Nose options. The test environment requires that a
setup.cfg is locally present including various required options.
Note that when using this runner, Nose's "coverage" plugin will not be
able to provide coverage for SQLAlchemy itself, since SQLAlchemy is
imported into sys.modules before coverage is started. The special
script sqla_nose.py is provided as a top-level script which loads the
plugin in a special (somewhat hacky) way so that coverage against
SQLAlchemy itself is possible.
"""
from sqlalchemy.testing.plugin.noseplugin import NoseSQLAlchemy
import nose
def main():
nose.main(addplugins=[NoseSQLAlchemy()])
def setup_py_test():
"""Runner to use for the 'test_suite' entry of your setup.py.
Prevents any name clash shenanigans from the command line
argument "test" that the "setup.py test" command sends
to nose.
"""
nose.main(addplugins=[NoseSQLAlchemy()], argv=['runner'])

View File

@@ -0,0 +1,95 @@
from . import exclusions
from .. import schema, event
from . import config
__all__ = 'Table', 'Column',
table_options = {}
def Table(*args, **kw):
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
test_opts = dict([(k, kw.pop(k)) for k in kw.keys()
if k.startswith('test_')])
kw.update(table_options)
if exclusions.against('mysql'):
if 'mysql_engine' not in kw and 'mysql_type' not in kw:
if 'test_needs_fk' in test_opts or 'test_needs_acid' in test_opts:
kw['mysql_engine'] = 'InnoDB'
else:
kw['mysql_engine'] = 'MyISAM'
# Apply some default cascading rules for self-referential foreign keys.
# MySQL InnoDB has some issues around seleting self-refs too.
if exclusions.against('firebird'):
table_name = args[0]
unpack = (config.db.dialect.
identifier_preparer.unformat_identifiers)
# Only going after ForeignKeys in Columns. May need to
# expand to ForeignKeyConstraint too.
fks = [fk
for col in args if isinstance(col, schema.Column)
for fk in col.foreign_keys]
for fk in fks:
# root around in raw spec
ref = fk._colspec
if isinstance(ref, schema.Column):
name = ref.table.name
else:
# take just the table name: on FB there cannot be
# a schema, so the first element is always the
# table name, possibly followed by the field name
name = unpack(ref)[0]
if name == table_name:
if fk.ondelete is None:
fk.ondelete = 'CASCADE'
if fk.onupdate is None:
fk.onupdate = 'CASCADE'
return schema.Table(*args, **kw)
def Column(*args, **kw):
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
test_opts = dict([(k, kw.pop(k)) for k in kw.keys()
if k.startswith('test_')])
if not config.requirements.foreign_key_ddl.enabled:
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
col = schema.Column(*args, **kw)
if 'test_needs_autoincrement' in test_opts and \
kw.get('primary_key', False):
# allow any test suite to pick up on this
col.info['test_needs_autoincrement'] = True
# hardcoded rule for firebird, oracle; this should
# be moved out
if exclusions.against('firebird', 'oracle'):
def add_seq(c, tbl):
c._init_items(
schema.Sequence(_truncate_name(
config.db.dialect, tbl.name + '_' + c.name + '_seq'),
optional=True)
)
event.listen(col, 'after_parent_attach', add_seq, propagate=True)
return col
def _truncate_name(dialect, name):
if len(name) > dialect.max_identifier_length:
return name[0:max(dialect.max_identifier_length - 6, 0)] + \
"_" + hex(hash(name) % 64)[2:]
else:
return name

View File

@@ -0,0 +1,8 @@
from sqlalchemy.testing.suite.test_ddl import *
from sqlalchemy.testing.suite.test_insert import *
from sqlalchemy.testing.suite.test_sequence import *
from sqlalchemy.testing.suite.test_results import *
from sqlalchemy.testing.suite.test_update_delete import *
from sqlalchemy.testing.suite.test_reflection import *
from sqlalchemy.testing.suite.test_types import *

View File

@@ -0,0 +1,48 @@
from __future__ import with_statement
from .. import fixtures, config, util
from ..config import requirements
from ..assertions import eq_
from sqlalchemy import Table, Column, Integer, String
class TableDDLTest(fixtures.TestBase):
def _simple_fixture(self):
return Table('test_table', self.metadata,
Column('id', Integer, primary_key=True, autoincrement=False),
Column('data', String(50))
)
def _simple_roundtrip(self, table):
with config.db.begin() as conn:
conn.execute(table.insert().values((1, 'some data')))
result = conn.execute(table.select())
eq_(
result.first(),
(1, 'some data')
)
@requirements.create_table
@util.provide_metadata
def test_create_table(self):
table = self._simple_fixture()
table.create(
config.db, checkfirst=False
)
self._simple_roundtrip(table)
@requirements.drop_table
@util.provide_metadata
def test_drop_table(self):
table = self._simple_fixture()
table.create(
config.db, checkfirst=False
)
table.drop(
config.db, checkfirst=False
)
__all__ = ('TableDDLTest', )

View File

@@ -0,0 +1,207 @@
from .. import fixtures, config
from ..config import requirements
from .. import exclusions
from ..assertions import eq_
from .. import engines
from sqlalchemy import Integer, String, select, util
from ..schema import Table, Column
class LastrowidTest(fixtures.TablesTest):
run_deletes = 'each'
__requires__ = 'implements_get_lastrowid', 'autoincrement_insert'
__engine_options__ = {"implicit_returning": False}
@classmethod
def define_tables(cls, metadata):
Table('autoinc_pk', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('data', String(50))
)
Table('manual_pk', metadata,
Column('id', Integer, primary_key=True, autoincrement=False),
Column('data', String(50))
)
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
eq_(
row,
(config.db.dialect.default_sequence_base, "some data")
)
def test_autoincrement_on_insert(self):
config.db.execute(
self.tables.autoinc_pk.insert(),
data="some data"
)
self._assert_round_trip(self.tables.autoinc_pk, config.db)
def test_last_inserted_id(self):
r = config.db.execute(
self.tables.autoinc_pk.insert(),
data="some data"
)
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
eq_(
r.inserted_primary_key,
[pk]
)
@exclusions.fails_if(lambda: util.pypy, "lastrowid not maintained after "
"connection close")
@requirements.dbapi_lastrowid
def test_native_lastrowid_autoinc(self):
r = config.db.execute(
self.tables.autoinc_pk.insert(),
data="some data"
)
lastrowid = r.lastrowid
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
eq_(
lastrowid, pk
)
class InsertBehaviorTest(fixtures.TablesTest):
run_deletes = 'each'
@classmethod
def define_tables(cls, metadata):
Table('autoinc_pk', metadata,
Column('id', Integer, primary_key=True, \
test_needs_autoincrement=True),
Column('data', String(50))
)
def test_autoclose_on_insert(self):
if requirements.returning.enabled:
engine = engines.testing_engine(
options={'implicit_returning': False})
else:
engine = config.db
r = engine.execute(
self.tables.autoinc_pk.insert(),
data="some data"
)
assert r.closed
assert r.is_insert
assert not r.returns_rows
@requirements.returning
def test_autoclose_on_insert_implicit_returning(self):
r = config.db.execute(
self.tables.autoinc_pk.insert(),
data="some data"
)
assert r.closed
assert r.is_insert
assert not r.returns_rows
@requirements.empty_inserts
def test_empty_insert(self):
r = config.db.execute(
self.tables.autoinc_pk.insert(),
)
assert r.closed
r = config.db.execute(
self.tables.autoinc_pk.select().\
where(self.tables.autoinc_pk.c.id != None)
)
assert len(r.fetchall())
@requirements.insert_from_select
def test_insert_from_select(self):
table = self.tables.autoinc_pk
config.db.execute(
table.insert(),
[
dict(data="data1"),
dict(data="data2"),
dict(data="data3"),
]
)
config.db.execute(
table.insert(inline=True).
from_select(
("id", "data",), select([table.c.id + 5, table.c.data]).where(
table.c.data.in_(["data2", "data3"]))
),
)
eq_(
config.db.execute(
select([table.c.data]).order_by(table.c.data)
).fetchall(),
[("data1", ), ("data2", ), ("data2", ),
("data3", ), ("data3", )]
)
class ReturningTest(fixtures.TablesTest):
run_deletes = 'each'
__requires__ = 'returning', 'autoincrement_insert'
__engine_options__ = {"implicit_returning": True}
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
eq_(
row,
(config.db.dialect.default_sequence_base, "some data")
)
@classmethod
def define_tables(cls, metadata):
Table('autoinc_pk', metadata,
Column('id', Integer, primary_key=True, \
test_needs_autoincrement=True),
Column('data', String(50))
)
def test_explicit_returning_pk(self):
engine = config.db
table = self.tables.autoinc_pk
r = engine.execute(
table.insert().returning(
table.c.id),
data="some data"
)
pk = r.first()[0]
fetched_pk = config.db.scalar(select([table.c.id]))
eq_(fetched_pk, pk)
def test_autoincrement_on_insert_implcit_returning(self):
config.db.execute(
self.tables.autoinc_pk.insert(),
data="some data"
)
self._assert_round_trip(self.tables.autoinc_pk, config.db)
def test_last_inserted_id_implicit_returning(self):
r = config.db.execute(
self.tables.autoinc_pk.insert(),
data="some data"
)
pk = config.db.scalar(select([self.tables.autoinc_pk.c.id]))
eq_(
r.inserted_primary_key,
[pk]
)
__all__ = ('LastrowidTest', 'InsertBehaviorTest', 'ReturningTest')

View File

@@ -0,0 +1,431 @@
from __future__ import with_statement
import sqlalchemy as sa
from sqlalchemy import exc as sa_exc
from sqlalchemy import types as sql_types
from sqlalchemy import inspect
from sqlalchemy import MetaData, Integer, String
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.testing import engines, fixtures
from sqlalchemy.testing.schema import Table, Column
from sqlalchemy.testing import eq_, assert_raises_message
from sqlalchemy import testing
from .. import config
from sqlalchemy.schema import DDL, Index
from sqlalchemy import event
metadata, users = None, None
class HasTableTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
Table('test_table', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(50))
)
def test_has_table(self):
with config.db.begin() as conn:
assert config.db.dialect.has_table(conn, "test_table")
assert not config.db.dialect.has_table(conn, "nonexistent_table")
class ComponentReflectionTest(fixtures.TablesTest):
run_inserts = run_deletes = None
@classmethod
def define_tables(cls, metadata):
cls.define_reflected_tables(metadata, None)
if testing.requires.schemas.enabled:
cls.define_reflected_tables(metadata, "test_schema")
@classmethod
def define_reflected_tables(cls, metadata, schema):
if schema:
schema_prefix = schema + "."
else:
schema_prefix = ""
if testing.requires.self_referential_foreign_keys.enabled:
users = Table('users', metadata,
Column('user_id', sa.INT, primary_key=True),
Column('test1', sa.CHAR(5), nullable=False),
Column('test2', sa.Float(5), nullable=False),
Column('parent_user_id', sa.Integer,
sa.ForeignKey('%susers.user_id' % schema_prefix)),
schema=schema,
test_needs_fk=True,
)
else:
users = Table('users', metadata,
Column('user_id', sa.INT, primary_key=True),
Column('test1', sa.CHAR(5), nullable=False),
Column('test2', sa.Float(5), nullable=False),
schema=schema,
test_needs_fk=True,
)
Table("dingalings", metadata,
Column('dingaling_id', sa.Integer, primary_key=True),
Column('address_id', sa.Integer,
sa.ForeignKey('%semail_addresses.address_id' %
schema_prefix)),
Column('data', sa.String(30)),
schema=schema,
test_needs_fk=True,
)
Table('email_addresses', metadata,
Column('address_id', sa.Integer),
Column('remote_user_id', sa.Integer,
sa.ForeignKey(users.c.user_id)),
Column('email_address', sa.String(20)),
sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'),
schema=schema,
test_needs_fk=True,
)
if testing.requires.index_reflection.enabled:
cls.define_index(metadata, users)
if testing.requires.view_reflection.enabled:
cls.define_views(metadata, schema)
@classmethod
def define_index(cls, metadata, users):
Index("users_t_idx", users.c.test1, users.c.test2)
Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1)
@classmethod
def define_views(cls, metadata, schema):
for table_name in ('users', 'email_addresses'):
fullname = table_name
if schema:
fullname = "%s.%s" % (schema, table_name)
view_name = fullname + '_v'
query = "CREATE VIEW %s AS SELECT * FROM %s" % (
view_name, fullname)
event.listen(
metadata,
"after_create",
DDL(query)
)
event.listen(
metadata,
"before_drop",
DDL("DROP VIEW %s" % view_name)
)
@testing.requires.schema_reflection
def test_get_schema_names(self):
insp = inspect(testing.db)
self.assert_('test_schema' in insp.get_schema_names())
@testing.requires.schema_reflection
def test_dialect_initialize(self):
engine = engines.testing_engine()
assert not hasattr(engine.dialect, 'default_schema_name')
inspect(engine)
assert hasattr(engine.dialect, 'default_schema_name')
@testing.requires.schema_reflection
def test_get_default_schema_name(self):
insp = inspect(testing.db)
eq_(insp.default_schema_name, testing.db.dialect.default_schema_name)
@testing.provide_metadata
def _test_get_table_names(self, schema=None, table_type='table',
order_by=None):
meta = self.metadata
users, addresses, dingalings = self.tables.users, \
self.tables.email_addresses, self.tables.dingalings
insp = inspect(meta.bind)
if table_type == 'view':
table_names = insp.get_view_names(schema)
table_names.sort()
answer = ['email_addresses_v', 'users_v']
else:
table_names = insp.get_table_names(schema,
order_by=order_by)
if order_by == 'foreign_key':
answer = ['users', 'email_addresses', 'dingalings']
eq_(table_names, answer)
else:
answer = ['dingalings', 'email_addresses', 'users']
eq_(sorted(table_names), answer)
@testing.requires.table_reflection
def test_get_table_names(self):
self._test_get_table_names()
@testing.requires.table_reflection
@testing.requires.foreign_key_constraint_reflection
def test_get_table_names_fks(self):
self._test_get_table_names(order_by='foreign_key')
@testing.requires.table_reflection
@testing.requires.schemas
def test_get_table_names_with_schema(self):
self._test_get_table_names('test_schema')
@testing.requires.view_reflection
def test_get_view_names(self):
self._test_get_table_names(table_type='view')
@testing.requires.view_reflection
@testing.requires.schemas
def test_get_view_names_with_schema(self):
self._test_get_table_names('test_schema', table_type='view')
def _test_get_columns(self, schema=None, table_type='table'):
meta = MetaData(testing.db)
users, addresses, dingalings = self.tables.users, \
self.tables.email_addresses, self.tables.dingalings
table_names = ['users', 'email_addresses']
if table_type == 'view':
table_names = ['users_v', 'email_addresses_v']
insp = inspect(meta.bind)
for table_name, table in zip(table_names, (users,
addresses)):
schema_name = schema
cols = insp.get_columns(table_name, schema=schema_name)
self.assert_(len(cols) > 0, len(cols))
# should be in order
for i, col in enumerate(table.columns):
eq_(col.name, cols[i]['name'])
ctype = cols[i]['type'].__class__
ctype_def = col.type
if isinstance(ctype_def, sa.types.TypeEngine):
ctype_def = ctype_def.__class__
# Oracle returns Date for DateTime.
if testing.against('oracle') and ctype_def \
in (sql_types.Date, sql_types.DateTime):
ctype_def = sql_types.Date
# assert that the desired type and return type share
# a base within one of the generic types.
self.assert_(len(set(ctype.__mro__).
intersection(ctype_def.__mro__).intersection([
sql_types.Integer,
sql_types.Numeric,
sql_types.DateTime,
sql_types.Date,
sql_types.Time,
sql_types.String,
sql_types._Binary,
])) > 0, '%s(%s), %s(%s)' % (col.name,
col.type, cols[i]['name'], ctype))
if not col.primary_key:
assert cols[i]['default'] is None
@testing.requires.table_reflection
def test_get_columns(self):
self._test_get_columns()
@testing.requires.table_reflection
@testing.requires.schemas
def test_get_columns_with_schema(self):
self._test_get_columns(schema='test_schema')
@testing.requires.view_reflection
def test_get_view_columns(self):
self._test_get_columns(table_type='view')
@testing.requires.view_reflection
@testing.requires.schemas
def test_get_view_columns_with_schema(self):
self._test_get_columns(schema='test_schema', table_type='view')
@testing.provide_metadata
def _test_get_pk_constraint(self, schema=None):
meta = self.metadata
users, addresses = self.tables.users, self.tables.email_addresses
insp = inspect(meta.bind)
users_cons = insp.get_pk_constraint(users.name, schema=schema)
users_pkeys = users_cons['constrained_columns']
eq_(users_pkeys, ['user_id'])
addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
addr_pkeys = addr_cons['constrained_columns']
eq_(addr_pkeys, ['address_id'])
with testing.requires.reflects_pk_names.fail_if():
eq_(addr_cons['name'], 'email_ad_pk')
@testing.requires.primary_key_constraint_reflection
def test_get_pk_constraint(self):
self._test_get_pk_constraint()
@testing.requires.table_reflection
@testing.requires.primary_key_constraint_reflection
@testing.requires.schemas
def test_get_pk_constraint_with_schema(self):
self._test_get_pk_constraint(schema='test_schema')
@testing.requires.table_reflection
@testing.provide_metadata
def test_deprecated_get_primary_keys(self):
meta = self.metadata
users = self.tables.users
insp = Inspector(meta.bind)
assert_raises_message(
sa_exc.SADeprecationWarning,
"Call to deprecated method get_primary_keys."
" Use get_pk_constraint instead.",
insp.get_primary_keys, users.name
)
@testing.provide_metadata
def _test_get_foreign_keys(self, schema=None):
meta = self.metadata
users, addresses, dingalings = self.tables.users, \
self.tables.email_addresses, self.tables.dingalings
insp = inspect(meta.bind)
expected_schema = schema
# users
users_fkeys = insp.get_foreign_keys(users.name,
schema=schema)
fkey1 = users_fkeys[0]
with testing.requires.named_constraints.fail_if():
self.assert_(fkey1['name'] is not None)
eq_(fkey1['referred_schema'], expected_schema)
eq_(fkey1['referred_table'], users.name)
eq_(fkey1['referred_columns'], ['user_id', ])
if testing.requires.self_referential_foreign_keys.enabled:
eq_(fkey1['constrained_columns'], ['parent_user_id'])
#addresses
addr_fkeys = insp.get_foreign_keys(addresses.name,
schema=schema)
fkey1 = addr_fkeys[0]
with testing.requires.named_constraints.fail_if():
self.assert_(fkey1['name'] is not None)
eq_(fkey1['referred_schema'], expected_schema)
eq_(fkey1['referred_table'], users.name)
eq_(fkey1['referred_columns'], ['user_id', ])
eq_(fkey1['constrained_columns'], ['remote_user_id'])
@testing.requires.foreign_key_constraint_reflection
def test_get_foreign_keys(self):
self._test_get_foreign_keys()
@testing.requires.foreign_key_constraint_reflection
@testing.requires.schemas
def test_get_foreign_keys_with_schema(self):
self._test_get_foreign_keys(schema='test_schema')
@testing.provide_metadata
def _test_get_indexes(self, schema=None):
meta = self.metadata
users, addresses, dingalings = self.tables.users, \
self.tables.email_addresses, self.tables.dingalings
# The database may decide to create indexes for foreign keys, etc.
# so there may be more indexes than expected.
insp = inspect(meta.bind)
indexes = insp.get_indexes('users', schema=schema)
expected_indexes = [
{'unique': False,
'column_names': ['test1', 'test2'],
'name': 'users_t_idx'},
{'unique': False,
'column_names': ['user_id', 'test2', 'test1'],
'name': 'users_all_idx'}
]
index_names = [d['name'] for d in indexes]
for e_index in expected_indexes:
assert e_index['name'] in index_names
index = indexes[index_names.index(e_index['name'])]
for key in e_index:
eq_(e_index[key], index[key])
@testing.requires.index_reflection
def test_get_indexes(self):
self._test_get_indexes()
@testing.requires.index_reflection
@testing.requires.schemas
def test_get_indexes_with_schema(self):
self._test_get_indexes(schema='test_schema')
@testing.provide_metadata
def _test_get_view_definition(self, schema=None):
meta = self.metadata
users, addresses, dingalings = self.tables.users, \
self.tables.email_addresses, self.tables.dingalings
view_name1 = 'users_v'
view_name2 = 'email_addresses_v'
insp = inspect(meta.bind)
v1 = insp.get_view_definition(view_name1, schema=schema)
self.assert_(v1)
v2 = insp.get_view_definition(view_name2, schema=schema)
self.assert_(v2)
@testing.requires.view_reflection
def test_get_view_definition(self):
self._test_get_view_definition()
@testing.requires.view_reflection
@testing.requires.schemas
def test_get_view_definition_with_schema(self):
self._test_get_view_definition(schema='test_schema')
@testing.only_on("postgresql", "PG specific feature")
@testing.provide_metadata
def _test_get_table_oid(self, table_name, schema=None):
meta = self.metadata
users, addresses, dingalings = self.tables.users, \
self.tables.email_addresses, self.tables.dingalings
insp = inspect(meta.bind)
oid = insp.get_table_oid(table_name, schema)
self.assert_(isinstance(oid, (int, long)))
def test_get_table_oid(self):
self._test_get_table_oid('users')
@testing.requires.schemas
def test_get_table_oid_with_schema(self):
self._test_get_table_oid('users', schema='test_schema')
@testing.provide_metadata
def test_autoincrement_col(self):
"""test that 'autoincrement' is reflected according to sqla's policy.
Don't mark this test as unsupported for any backend !
(technically it fails with MySQL InnoDB since "id" comes before "id2")
A backend is better off not returning "autoincrement" at all,
instead of potentially returning "False" for an auto-incrementing
primary key column.
"""
meta = self.metadata
insp = inspect(meta.bind)
for tname, cname in [
('users', 'user_id'),
('email_addresses', 'address_id'),
('dingalings', 'dingaling_id'),
]:
cols = insp.get_columns(tname)
id_ = dict((c['name'], c) for c in cols)[cname]
assert id_.get('autoincrement', True)
__all__ = ('ComponentReflectionTest', 'HasTableTest')

View File

@@ -0,0 +1,69 @@
from .. import fixtures, config
from ..config import requirements
from .. import exclusions
from ..assertions import eq_
from .. import engines
from sqlalchemy import Integer, String, select, util
from ..schema import Table, Column
class RowFetchTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
Table('plain_pk', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(50))
)
@classmethod
def insert_data(cls):
config.db.execute(
cls.tables.plain_pk.insert(),
[
{"id":1, "data":"d1"},
{"id":2, "data":"d2"},
{"id":3, "data":"d3"},
]
)
def test_via_string(self):
row = config.db.execute(
self.tables.plain_pk.select().\
order_by(self.tables.plain_pk.c.id)
).first()
eq_(
row['id'], 1
)
eq_(
row['data'], "d1"
)
def test_via_int(self):
row = config.db.execute(
self.tables.plain_pk.select().\
order_by(self.tables.plain_pk.c.id)
).first()
eq_(
row[0], 1
)
eq_(
row[1], "d1"
)
def test_via_col_object(self):
row = config.db.execute(
self.tables.plain_pk.select().\
order_by(self.tables.plain_pk.c.id)
).first()
eq_(
row[self.tables.plain_pk.c.id], 1
)
eq_(
row[self.tables.plain_pk.c.data], "d1"
)

View File

@@ -0,0 +1,126 @@
from .. import fixtures, config
from ..config import requirements
from ..assertions import eq_
from ... import testing
from ... import Integer, String, Sequence, schema
from ..schema import Table, Column
class SequenceTest(fixtures.TablesTest):
__requires__ = ('sequences',)
run_create_tables = 'each'
@classmethod
def define_tables(cls, metadata):
Table('seq_pk', metadata,
Column('id', Integer, Sequence('tab_id_seq'), primary_key=True),
Column('data', String(50))
)
Table('seq_opt_pk', metadata,
Column('id', Integer, Sequence('tab_id_seq', optional=True),
primary_key=True),
Column('data', String(50))
)
def test_insert_roundtrip(self):
config.db.execute(
self.tables.seq_pk.insert(),
data="some data"
)
self._assert_round_trip(self.tables.seq_pk, config.db)
def test_insert_lastrowid(self):
r = config.db.execute(
self.tables.seq_pk.insert(),
data="some data"
)
eq_(
r.inserted_primary_key,
[1]
)
def test_nextval_direct(self):
r = config.db.execute(
self.tables.seq_pk.c.id.default
)
eq_(
r, 1
)
@requirements.sequences_optional
def test_optional_seq(self):
r = config.db.execute(
self.tables.seq_opt_pk.insert(),
data="some data"
)
eq_(
r.inserted_primary_key,
[1]
)
def _assert_round_trip(self, table, conn):
row = conn.execute(table.select()).first()
eq_(
row,
(1, "some data")
)
class HasSequenceTest(fixtures.TestBase):
__requires__ = 'sequences',
def test_has_sequence(self):
s1 = Sequence('user_id_seq')
testing.db.execute(schema.CreateSequence(s1))
try:
eq_(testing.db.dialect.has_sequence(testing.db,
'user_id_seq'), True)
finally:
testing.db.execute(schema.DropSequence(s1))
@testing.requires.schemas
def test_has_sequence_schema(self):
s1 = Sequence('user_id_seq', schema="test_schema")
testing.db.execute(schema.CreateSequence(s1))
try:
eq_(testing.db.dialect.has_sequence(testing.db,
'user_id_seq', schema="test_schema"), True)
finally:
testing.db.execute(schema.DropSequence(s1))
def test_has_sequence_neg(self):
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
False)
@testing.requires.schemas
def test_has_sequence_schemas_neg(self):
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
schema="test_schema"),
False)
@testing.requires.schemas
def test_has_sequence_default_not_in_remote(self):
s1 = Sequence('user_id_seq')
testing.db.execute(schema.CreateSequence(s1))
try:
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq',
schema="test_schema"),
False)
finally:
testing.db.execute(schema.DropSequence(s1))
@testing.requires.schemas
def test_has_sequence_remote_not_in_default(self):
s1 = Sequence('user_id_seq', schema="test_schema")
testing.db.execute(schema.CreateSequence(s1))
try:
eq_(testing.db.dialect.has_sequence(testing.db, 'user_id_seq'),
False)
finally:
testing.db.execute(schema.DropSequence(s1))

View File

@@ -0,0 +1,393 @@
# coding: utf-8
from .. import fixtures, config
from ..assertions import eq_
from ..config import requirements
from sqlalchemy import Integer, Unicode, UnicodeText, select
from sqlalchemy import Date, DateTime, Time, MetaData, String, \
Text, Numeric, Float
from ..schema import Table, Column
from ... import testing
import decimal
import datetime
class _UnicodeFixture(object):
__requires__ = 'unicode_data',
data = u"Alors vous imaginez ma surprise, au lever du jour, "\
u"quand une drôle de petite voix ma réveillé. Elle "\
u"disait: « Sil vous plaît… dessine-moi un mouton! »"
@classmethod
def define_tables(cls, metadata):
Table('unicode_table', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('unicode_data', cls.datatype),
)
def test_round_trip(self):
unicode_table = self.tables.unicode_table
config.db.execute(
unicode_table.insert(),
{
'unicode_data': self.data,
}
)
row = config.db.execute(
select([
unicode_table.c.unicode_data,
])
).first()
eq_(
row,
(self.data, )
)
assert isinstance(row[0], unicode)
def test_round_trip_executemany(self):
unicode_table = self.tables.unicode_table
config.db.execute(
unicode_table.insert(),
[
{
'unicode_data': self.data,
}
for i in xrange(3)
]
)
rows = config.db.execute(
select([
unicode_table.c.unicode_data,
])
).fetchall()
eq_(
rows,
[(self.data, ) for i in xrange(3)]
)
for row in rows:
assert isinstance(row[0], unicode)
def _test_empty_strings(self):
unicode_table = self.tables.unicode_table
config.db.execute(
unicode_table.insert(),
{"unicode_data": u''}
)
row = config.db.execute(
select([unicode_table.c.unicode_data])
).first()
eq_(row, (u'',))
class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
__requires__ = 'unicode_data',
datatype = Unicode(255)
@requirements.empty_strings_varchar
def test_empty_strings_varchar(self):
self._test_empty_strings()
class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
__requires__ = 'unicode_data', 'text_type'
datatype = UnicodeText()
@requirements.empty_strings_text
def test_empty_strings_text(self):
self._test_empty_strings()
class TextTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
Table('text_table', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('text_data', Text),
)
def test_text_roundtrip(self):
text_table = self.tables.text_table
config.db.execute(
text_table.insert(),
{"text_data": 'some text'}
)
row = config.db.execute(
select([text_table.c.text_data])
).first()
eq_(row, ('some text',))
def test_text_empty_strings(self):
text_table = self.tables.text_table
config.db.execute(
text_table.insert(),
{"text_data": ''}
)
row = config.db.execute(
select([text_table.c.text_data])
).first()
eq_(row, ('',))
class StringTest(fixtures.TestBase):
@requirements.unbounded_varchar
def test_nolength_string(self):
metadata = MetaData()
foo = Table('foo', metadata,
Column('one', String)
)
foo.create(config.db)
foo.drop(config.db)
class _DateFixture(object):
compare = None
@classmethod
def define_tables(cls, metadata):
Table('date_table', metadata,
Column('id', Integer, primary_key=True,
test_needs_autoincrement=True),
Column('date_data', cls.datatype),
)
def test_round_trip(self):
date_table = self.tables.date_table
config.db.execute(
date_table.insert(),
{'date_data': self.data}
)
row = config.db.execute(
select([
date_table.c.date_data,
])
).first()
compare = self.compare or self.data
eq_(row,
(compare, ))
assert isinstance(row[0], type(compare))
def test_null(self):
date_table = self.tables.date_table
config.db.execute(
date_table.insert(),
{'date_data': None}
)
row = config.db.execute(
select([
date_table.c.date_data,
])
).first()
eq_(row, (None,))
class DateTimeTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'datetime',
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'datetime_microseconds',
datatype = DateTime
data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
class TimeTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'time',
datatype = Time
data = datetime.time(12, 57, 18)
class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'time_microseconds',
datatype = Time
data = datetime.time(12, 57, 18, 396)
class DateTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'date',
datatype = Date
data = datetime.date(2012, 10, 15)
class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'date',
datatype = Date
data = datetime.datetime(2012, 10, 15, 12, 57, 18)
compare = datetime.date(2012, 10, 15)
class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'datetime_historic',
datatype = DateTime
data = datetime.datetime(1850, 11, 10, 11, 52, 35)
class DateHistoricTest(_DateFixture, fixtures.TablesTest):
__requires__ = 'date_historic',
datatype = Date
data = datetime.date(1727, 4, 1)
class NumericTest(fixtures.TestBase):
@testing.emits_warning(r".*does \*not\* support Decimal objects natively")
@testing.provide_metadata
def _do_test(self, type_, input_, output, filter_=None, check_scale=False):
metadata = self.metadata
t = Table('t', metadata, Column('x', type_))
t.create()
t.insert().execute([{'x':x} for x in input_])
result = set([row[0] for row in t.select().execute()])
output = set(output)
if filter_:
result = set(filter_(x) for x in result)
output = set(filter_(x) for x in output)
eq_(result, output)
if check_scale:
eq_(
[str(x) for x in result],
[str(x) for x in output],
)
def test_numeric_as_decimal(self):
self._do_test(
Numeric(precision=8, scale=4),
[15.7563, decimal.Decimal("15.7563"), None],
[decimal.Decimal("15.7563"), None],
)
def test_numeric_as_float(self):
self._do_test(
Numeric(precision=8, scale=4, asdecimal=False),
[15.7563, decimal.Decimal("15.7563"), None],
[15.7563, None],
)
def test_float_as_decimal(self):
self._do_test(
Float(precision=8, asdecimal=True),
[15.7563, decimal.Decimal("15.7563"), None],
[decimal.Decimal("15.7563"), None],
)
def test_float_as_float(self):
self._do_test(
Float(precision=8),
[15.7563, decimal.Decimal("15.7563")],
[15.7563],
filter_=lambda n: n is not None and round(n, 5) or None
)
@testing.requires.precision_numerics_general
def test_precision_decimal(self):
numbers = set([
decimal.Decimal("54.234246451650"),
decimal.Decimal("0.004354"),
decimal.Decimal("900.0"),
])
self._do_test(
Numeric(precision=18, scale=12),
numbers,
numbers,
)
@testing.requires.precision_numerics_enotation_large
def test_enotation_decimal(self):
"""test exceedingly small decimals.
Decimal reports values with E notation when the exponent
is greater than 6.
"""
numbers = set([
decimal.Decimal('1E-2'),
decimal.Decimal('1E-3'),
decimal.Decimal('1E-4'),
decimal.Decimal('1E-5'),
decimal.Decimal('1E-6'),
decimal.Decimal('1E-7'),
decimal.Decimal('1E-8'),
decimal.Decimal("0.01000005940696"),
decimal.Decimal("0.00000005940696"),
decimal.Decimal("0.00000000000696"),
decimal.Decimal("0.70000000000696"),
decimal.Decimal("696E-12"),
])
self._do_test(
Numeric(precision=18, scale=14),
numbers,
numbers
)
@testing.requires.precision_numerics_enotation_large
def test_enotation_decimal_large(self):
"""test exceedingly large decimals.
"""
numbers = set([
decimal.Decimal('4E+8'),
decimal.Decimal("5748E+15"),
decimal.Decimal('1.521E+15'),
decimal.Decimal('00000000000000.1E+12'),
])
self._do_test(
Numeric(precision=25, scale=2),
numbers,
numbers
)
@testing.requires.precision_numerics_many_significant_digits
def test_many_significant_digits(self):
numbers = set([
decimal.Decimal("31943874831932418390.01"),
decimal.Decimal("319438950232418390.273596"),
decimal.Decimal("87673.594069654243"),
])
self._do_test(
Numeric(precision=38, scale=12),
numbers,
numbers
)
@testing.requires.precision_numerics_retains_significant_digits
def test_numeric_no_decimal(self):
numbers = set([
decimal.Decimal("1.000")
])
self._do_test(
Numeric(precision=5, scale=3),
numbers,
numbers,
check_scale=True
)
__all__ = ('UnicodeVarcharTest', 'UnicodeTextTest',
'DateTest', 'DateTimeTest', 'TextTest',
'NumericTest',
'DateTimeHistoricTest', 'DateTimeCoercedToDateTimeTest',
'TimeMicrosecondsTest', 'TimeTest', 'DateTimeMicrosecondsTest',
'DateHistoricTest', 'StringTest')

View File

@@ -0,0 +1,62 @@
from .. import fixtures, config
from ..assertions import eq_
from sqlalchemy import Integer, String
from ..schema import Table, Column
class SimpleUpdateDeleteTest(fixtures.TablesTest):
run_deletes = 'each'
@classmethod
def define_tables(cls, metadata):
Table('plain_pk', metadata,
Column('id', Integer, primary_key=True),
Column('data', String(50))
)
@classmethod
def insert_data(cls):
config.db.execute(
cls.tables.plain_pk.insert(),
[
{"id":1, "data":"d1"},
{"id":2, "data":"d2"},
{"id":3, "data":"d3"},
]
)
def test_update(self):
t = self.tables.plain_pk
r = config.db.execute(
t.update().where(t.c.id == 2),
data="d2_new"
)
assert not r.is_insert
assert not r.returns_rows
eq_(
config.db.execute(t.select().order_by(t.c.id)).fetchall(),
[
(1, "d1"),
(2, "d2_new"),
(3, "d3")
]
)
def test_delete(self):
t = self.tables.plain_pk
r = config.db.execute(
t.delete().where(t.c.id == 2)
)
assert not r.is_insert
assert not r.returns_rows
eq_(
config.db.execute(t.select().order_by(t.c.id)).fetchall(),
[
(1, "d1"),
(3, "d3")
]
)
__all__ = ('SimpleUpdateDeleteTest', )

View File

@@ -0,0 +1,199 @@
from ..util import jython, pypy, defaultdict, decorator
import decimal
import gc
import time
import random
import sys
import types
if jython:
def jython_gc_collect(*args):
"""aggressive gc.collect for tests."""
gc.collect()
time.sleep(0.1)
gc.collect()
gc.collect()
return 0
# "lazy" gc, for VM's that don't GC on refcount == 0
gc_collect = lazy_gc = jython_gc_collect
elif pypy:
def pypy_gc_collect(*args):
gc.collect()
gc.collect()
gc_collect = lazy_gc = pypy_gc_collect
else:
# assume CPython - straight gc.collect, lazy_gc() is a pass
gc_collect = gc.collect
def lazy_gc():
pass
def picklers():
picklers = set()
# Py2K
try:
import cPickle
picklers.add(cPickle)
except ImportError:
pass
# end Py2K
import pickle
picklers.add(pickle)
# yes, this thing needs this much testing
for pickle_ in picklers:
for protocol in -1, 0, 1, 2:
yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
def round_decimal(value, prec):
if isinstance(value, float):
return round(value, prec)
# can also use shift() here but that is 2.6 only
return (value * decimal.Decimal("1" + "0" * prec)
).to_integral(decimal.ROUND_FLOOR) / \
pow(10, prec)
class RandomSet(set):
def __iter__(self):
l = list(set.__iter__(self))
random.shuffle(l)
return iter(l)
def pop(self):
index = random.randint(0, len(self) - 1)
item = list(set.__iter__(self))[index]
self.remove(item)
return item
def union(self, other):
return RandomSet(set.union(self, other))
def difference(self, other):
return RandomSet(set.difference(self, other))
def intersection(self, other):
return RandomSet(set.intersection(self, other))
def copy(self):
return RandomSet(self)
def conforms_partial_ordering(tuples, sorted_elements):
"""True if the given sorting conforms to the given partial ordering."""
deps = defaultdict(set)
for parent, child in tuples:
deps[parent].add(child)
for i, node in enumerate(sorted_elements):
for n in sorted_elements[i:]:
if node in deps[n]:
return False
else:
return True
def all_partial_orderings(tuples, elements):
edges = defaultdict(set)
for parent, child in tuples:
edges[child].add(parent)
def _all_orderings(elements):
if len(elements) == 1:
yield list(elements)
else:
for elem in elements:
subset = set(elements).difference([elem])
if not subset.intersection(edges[elem]):
for sub_ordering in _all_orderings(subset):
yield [elem] + sub_ordering
return iter(_all_orderings(elements))
def function_named(fn, name):
"""Return a function with a given __name__.
Will assign to __name__ and return the original function if possible on
the Python implementation, otherwise a new function will be constructed.
This function should be phased out as much as possible
in favor of @decorator. Tests that "generate" many named tests
should be modernized.
"""
try:
fn.__name__ = name
except TypeError:
fn = types.FunctionType(fn.func_code, fn.func_globals, name,
fn.func_defaults, fn.func_closure)
return fn
def run_as_contextmanager(ctx, fn, *arg, **kw):
"""Run the given function under the given contextmanager,
simulating the behavior of 'with' to support older
Python versions.
"""
obj = ctx.__enter__()
try:
result = fn(obj, *arg, **kw)
ctx.__exit__(None, None, None)
return result
except:
exc_info = sys.exc_info()
raise_ = ctx.__exit__(*exc_info)
if raise_ is None:
raise
else:
return raise_
def rowset(results):
"""Converts the results of sql execution into a plain set of column tuples.
Useful for asserting the results of an unordered query.
"""
return set([tuple(row) for row in results])
def fail(msg):
assert False, msg
@decorator
def provide_metadata(fn, *args, **kw):
"""Provide bound MetaData for a single test, dropping afterwards."""
from . import config
from sqlalchemy import schema
metadata = schema.MetaData(config.db)
self = args[0]
prev_meta = getattr(self, 'metadata', None)
self.metadata = metadata
try:
return fn(*args, **kw)
finally:
metadata.drop_all()
self.metadata = prev_meta
class adict(dict):
"""Dict keys available as attributes. Shadows."""
def __getattribute__(self, key):
try:
return self[key]
except KeyError:
return dict.__getattribute__(self, key)
def get_all(self, *keys):
return tuple([self[key] for key in keys])

View File

@@ -0,0 +1,47 @@
from __future__ import absolute_import
import warnings
from .. import exc as sa_exc
from .. import util
def testing_warn(msg, stacklevel=3):
"""Replaces sqlalchemy.util.warn during tests."""
filename = "sqlalchemy.testing.warnings"
lineno = 1
if isinstance(msg, basestring):
warnings.warn_explicit(msg, sa_exc.SAWarning, filename, lineno)
else:
warnings.warn_explicit(msg, filename, lineno)
def resetwarnings():
"""Reset warning behavior to testing defaults."""
util.warn = util.langhelpers.warn = testing_warn
warnings.filterwarnings('ignore',
category=sa_exc.SAPendingDeprecationWarning)
warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
warnings.filterwarnings('error', category=sa_exc.SAWarning)
def assert_warnings(fn, warnings):
"""Assert that each of the given warnings are emitted by fn."""
from .assertions import eq_, emits_warning
canary = []
orig_warn = util.warn
def capture_warnings(*args, **kw):
orig_warn(*args, **kw)
popwarn = warnings.pop(0)
canary.append(popwarn)
eq_(args[0], popwarn)
util.warn = util.langhelpers.warn = capture_warnings
result = emits_warning()(fn)()
assert canary, "No warning was emitted"
return result