Commit cef93db0 by Brian Coca

Merge pull request #10754 from invenia/devel

Python 2/3 compatibility fixes to parsing in v2.
parents 6a35463e 3e25f633
......@@ -20,7 +20,6 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from six import iteritems, string_types
from types import NoneType
from ansible.errors import AnsibleParserError
from ansible.plugins import module_loader
......@@ -165,7 +164,7 @@ class ModuleArgsParser:
# form is like: local_action: copy src=a dest=b ... pretty common
check_raw = action in ('command', 'shell', 'script')
args = parse_kv(thing, check_raw=check_raw)
elif isinstance(thing, NoneType):
elif thing is None:
# this can happen with modules which take no params, like ping:
args = None
else:
......
......@@ -19,14 +19,17 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
class AnsibleBaseYAMLObject:
from six import text_type
class AnsibleBaseYAMLObject(object):
'''
the base class used to sub-class python built-in objects
so that we can add attributes to them during yaml parsing
'''
_data_source = None
_line_number = 0
_data_source = None
_line_number = 0
_column_number = 0
def _get_ansible_position(self):
......@@ -36,21 +39,27 @@ class AnsibleBaseYAMLObject:
try:
(src, line, col) = obj
except (TypeError, ValueError):
raise AssertionError('ansible_pos can only be set with a tuple/list of three values: source, line number, column number')
self._data_source = src
self._line_number = line
raise AssertionError(
'ansible_pos can only be set with a tuple/list '
'of three values: source, line number, column number'
)
self._data_source = src
self._line_number = line
self._column_number = col
ansible_pos = property(_get_ansible_position, _set_ansible_position)
class AnsibleMapping(AnsibleBaseYAMLObject, dict):
''' sub class for dictionaries '''
pass
class AnsibleUnicode(AnsibleBaseYAMLObject, unicode):
class AnsibleUnicode(AnsibleBaseYAMLObject, text_type):
''' sub class for unicode objects '''
pass
class AnsibleSequence(AnsibleBaseYAMLObject, list):
''' sub class for lists '''
pass
......@@ -19,6 +19,8 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from six import string_types, text_type, binary_type, PY3
# to_bytes and to_unicode were written by Toshio Kuratomi for the
# python-kitchen library https://pypi.python.org/pypi/kitchen
# They are licensed in kitchen under the terms of the GPLv2+
......@@ -35,6 +37,9 @@ _LATIN1_ALIASES = frozenset(('latin-1', 'LATIN-1', 'latin1', 'LATIN1',
# EXCEPTION_CONVERTERS is defined below due to using to_unicode
if PY3:
basestring = (str, bytes)
def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None):
'''Convert an object into a :class:`unicode` string
......@@ -89,12 +94,12 @@ def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None):
# Could use isbasestring/isunicode here but we want this code to be as
# fast as possible
if isinstance(obj, basestring):
if isinstance(obj, unicode):
if isinstance(obj, text_type):
return obj
if encoding in _UTF8_ALIASES:
return unicode(obj, 'utf-8', errors)
return text_type(obj, 'utf-8', errors)
if encoding in _LATIN1_ALIASES:
return unicode(obj, 'latin-1', errors)
return text_type(obj, 'latin-1', errors)
return obj.decode(encoding, errors)
if not nonstring:
......@@ -110,19 +115,19 @@ def to_unicode(obj, encoding='utf-8', errors='replace', nonstring=None):
simple = None
if not simple:
try:
simple = str(obj)
simple = text_type(obj)
except UnicodeError:
try:
simple = obj.__str__()
except (UnicodeError, AttributeError):
simple = u''
if isinstance(simple, str):
return unicode(simple, encoding, errors)
if isinstance(simple, binary_type):
return text_type(simple, encoding, errors)
return simple
elif nonstring in ('repr', 'strict'):
obj_repr = repr(obj)
if isinstance(obj_repr, str):
obj_repr = unicode(obj_repr, encoding, errors)
if isinstance(obj_repr, binary_type):
obj_repr = text_type(obj_repr, encoding, errors)
if nonstring == 'repr':
return obj_repr
raise TypeError('to_unicode was given "%(obj)s" which is neither'
......@@ -198,19 +203,19 @@ def to_bytes(obj, encoding='utf-8', errors='replace', nonstring=None):
# Could use isbasestring, isbytestring here but we want this to be as fast
# as possible
if isinstance(obj, basestring):
if isinstance(obj, str):
if isinstance(obj, binary_type):
return obj
return obj.encode(encoding, errors)
if not nonstring:
nonstring = 'simplerepr'
if nonstring == 'empty':
return ''
return b''
elif nonstring == 'passthru':
return obj
elif nonstring == 'simplerepr':
try:
simple = str(obj)
simple = binary_type(obj)
except UnicodeError:
try:
simple = obj.__str__()
......@@ -220,19 +225,19 @@ def to_bytes(obj, encoding='utf-8', errors='replace', nonstring=None):
try:
simple = obj.__unicode__()
except (AttributeError, UnicodeError):
simple = ''
if isinstance(simple, unicode):
simple = b''
if isinstance(simple, text_type):
simple = simple.encode(encoding, 'replace')
return simple
elif nonstring in ('repr', 'strict'):
try:
obj_repr = obj.__repr__()
except (AttributeError, UnicodeError):
obj_repr = ''
if isinstance(obj_repr, unicode):
obj_repr = b''
if isinstance(obj_repr, text_type):
obj_repr = obj_repr.encode(encoding, errors)
else:
obj_repr = str(obj_repr)
obj_repr = binary_type(obj_repr)
if nonstring == 'repr':
return obj_repr
raise TypeError('to_bytes was given "%(obj)s" which is neither'
......
......@@ -19,6 +19,7 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from six import PY2
from yaml.scanner import ScannerError
from ansible.compat.tests import unittest
......@@ -79,6 +80,11 @@ class TestDataLoaderWithVault(unittest.TestCase):
3135306561356164310a343937653834643433343734653137383339323330626437313562306630
3035
"""
with patch('__builtin__.open', mock_open(read_data=vaulted_data)):
if PY2:
builtins_name = '__builtin__'
else:
builtins_name = 'builtins'
with patch(builtins_name + '.open', mock_open(read_data=vaulted_data)):
output = self._loader.load_from_file('dummy_vault.txt')
self.assertEqual(output, dict(foo='bar'))
......@@ -24,11 +24,14 @@ import os
import shutil
import time
import tempfile
import six
from binascii import unhexlify
from binascii import hexlify
from nose.plugins.skip import SkipTest
from ansible.compat.tests import unittest
from ansible.utils.unicode import to_bytes, to_unicode
from ansible import errors
from ansible.parsing.vault import VaultLib
......@@ -63,13 +66,13 @@ class TestVaultLib(unittest.TestCase):
'decrypt',
'_add_header',
'_split_header',]
for slot in slots:
for slot in slots:
assert hasattr(v, slot), "VaultLib is missing the %s method" % slot
def test_is_encrypted(self):
v = VaultLib(None)
assert not v.is_encrypted("foobar"), "encryption check on plaintext failed"
data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify("ansible")
assert not v.is_encrypted(u"foobar"), "encryption check on plaintext failed"
data = u"$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(b"ansible")
assert v.is_encrypted(data), "encryption check on headered text failed"
def test_add_header(self):
......@@ -77,22 +80,22 @@ class TestVaultLib(unittest.TestCase):
v.cipher_name = "TEST"
sensitive_data = "ansible"
data = v._add_header(sensitive_data)
lines = data.split('\n')
lines = data.split(b'\n')
assert len(lines) > 1, "failed to properly add header"
header = lines[0]
header = to_unicode(lines[0])
assert header.endswith(';TEST'), "header does end with cipher name"
header_parts = header.split(';')
assert len(header_parts) == 3, "header has the wrong number of parts"
assert len(header_parts) == 3, "header has the wrong number of parts"
assert header_parts[0] == '$ANSIBLE_VAULT', "header does not start with $ANSIBLE_VAULT"
assert header_parts[1] == v.version, "header version is incorrect"
assert header_parts[2] == 'TEST', "header does end with cipher name"
def test_split_header(self):
v = VaultLib('ansible')
data = "$ANSIBLE_VAULT;9.9;TEST\nansible"
rdata = v._split_header(data)
lines = rdata.split('\n')
assert lines[0] == "ansible"
data = b"$ANSIBLE_VAULT;9.9;TEST\nansible"
rdata = v._split_header(data)
lines = rdata.split(b'\n')
assert lines[0] == b"ansible"
assert v.cipher_name == 'TEST', "cipher name was not set"
assert v.version == "9.9"
......@@ -100,11 +103,11 @@ class TestVaultLib(unittest.TestCase):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
raise SkipTest
v = VaultLib('ansible')
v.cipher_name = 'AES'
v.cipher_name = u'AES'
enc_data = v.encrypt("foobar")
dec_data = v.decrypt(enc_data)
assert enc_data != "foobar", "encryption failed"
assert dec_data == "foobar", "decryption failed"
assert dec_data == "foobar", "decryption failed"
def test_encrypt_decrypt_aes256(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
......@@ -114,20 +117,20 @@ class TestVaultLib(unittest.TestCase):
enc_data = v.encrypt("foobar")
dec_data = v.decrypt(enc_data)
assert enc_data != "foobar", "encryption failed"
assert dec_data == "foobar", "decryption failed"
assert dec_data == "foobar", "decryption failed"
def test_encrypt_encrypted(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
raise SkipTest
v = VaultLib('ansible')
v.cipher_name = 'AES'
data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify("ansible")
data = "$ANSIBLE_VAULT;9.9;TEST\n%s" % hexlify(six.b("ansible"))
error_hit = False
try:
enc_data = v.encrypt(data)
except errors.AnsibleError as e:
error_hit = True
assert error_hit, "No error was thrown when trying to encrypt data with a header"
assert error_hit, "No error was thrown when trying to encrypt data with a header"
def test_decrypt_decrypted(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
......@@ -139,7 +142,7 @@ class TestVaultLib(unittest.TestCase):
dec_data = v.decrypt(data)
except errors.AnsibleError as e:
error_hit = True
assert error_hit, "No error was thrown when trying to decrypt data without a header"
assert error_hit, "No error was thrown when trying to decrypt data without a header"
def test_cipher_not_set(self):
# not setting the cipher should default to AES256
......@@ -152,5 +155,5 @@ class TestVaultLib(unittest.TestCase):
enc_data = v.encrypt(data)
except errors.AnsibleError as e:
error_hit = True
assert not error_hit, "An error was thrown when trying to encrypt data without the cipher set"
assert v.cipher_name == "AES256", "cipher name is not set to AES256: %s" % v.cipher_name
assert not error_hit, "An error was thrown when trying to encrypt data without the cipher set"
assert v.cipher_name == "AES256", "cipher name is not set to AES256: %s" % v.cipher_name
......@@ -21,6 +21,7 @@ from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
#!/usr/bin/env python
import sys
import getpass
import os
import shutil
......@@ -32,6 +33,7 @@ from nose.plugins.skip import SkipTest
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch
from ansible.utils.unicode import to_bytes, to_unicode
from ansible import errors
from ansible.parsing.vault import VaultLib
......@@ -88,12 +90,12 @@ class TestVaultEditor(unittest.TestCase):
'read_data',
'write_data',
'shuffle_files']
for slot in slots:
for slot in slots:
assert hasattr(v, slot), "VaultLib is missing the %s method" % slot
@patch.object(VaultEditor, '_editor_shell_command')
def test_create_file(self, mock_editor_shell_command):
def sc_side_effect(filename):
return ['touch', filename]
mock_editor_shell_command.side_effect = sc_side_effect
......@@ -107,12 +109,16 @@ class TestVaultEditor(unittest.TestCase):
self.assertTrue(os.path.exists(tmp_file.name))
def test_decrypt_1_0(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
"""
Skip testing decrypting 1.0 files if we don't have access to AES, KDF or
Counter, or we are running on python3 since VaultAES hasn't been backported.
"""
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or sys.version > '3':
raise SkipTest
v10_file = tempfile.NamedTemporaryFile(delete=False)
with v10_file as f:
f.write(v10_data)
f.write(to_bytes(v10_data))
ve = VaultEditor(None, "ansible", v10_file.name)
......@@ -125,13 +131,13 @@ class TestVaultEditor(unittest.TestCase):
# verify decrypted content
f = open(v10_file.name, "rb")
fdata = f.read()
fdata = to_unicode(f.read())
f.close()
os.unlink(v10_file.name)
assert error_hit == False, "error decrypting 1.0 file"
assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip()
assert error_hit == False, "error decrypting 1.0 file"
assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip()
def test_decrypt_1_1(self):
......@@ -140,7 +146,7 @@ class TestVaultEditor(unittest.TestCase):
v11_file = tempfile.NamedTemporaryFile(delete=False)
with v11_file as f:
f.write(v11_data)
f.write(to_bytes(v11_data))
ve = VaultEditor(None, "ansible", v11_file.name)
......@@ -153,28 +159,32 @@ class TestVaultEditor(unittest.TestCase):
# verify decrypted content
f = open(v11_file.name, "rb")
fdata = f.read()
fdata = to_unicode(f.read())
f.close()
os.unlink(v11_file.name)
assert error_hit == False, "error decrypting 1.0 file"
assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip()
assert error_hit == False, "error decrypting 1.0 file"
assert fdata.strip() == "foo", "incorrect decryption of 1.0 file: %s" % fdata.strip()
def test_rekey_migration(self):
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2:
"""
Skip testing rekeying files if we don't have access to AES, KDF or
Counter, or we are running on python3 since VaultAES hasn't been backported.
"""
if not HAS_AES or not HAS_COUNTER or not HAS_PBKDF2 or sys.version > '3':
raise SkipTest
v10_file = tempfile.NamedTemporaryFile(delete=False)
with v10_file as f:
f.write(v10_data)
f.write(to_bytes(v10_data))
ve = VaultEditor(None, "ansible", v10_file.name)
# make sure the password functions for the cipher
error_hit = False
try:
try:
ve.rekey_file('ansible2')
except errors.AnsibleError as e:
error_hit = True
......@@ -184,7 +194,7 @@ class TestVaultEditor(unittest.TestCase):
fdata = f.read()
f.close()
assert error_hit == False, "error rekeying 1.0 file to 1.1"
assert error_hit == False, "error rekeying 1.0 file to 1.1"
# ensure filedata can be decrypted, is 1.1 and is AES256
vl = VaultLib("ansible2")
......@@ -198,7 +208,7 @@ class TestVaultEditor(unittest.TestCase):
os.unlink(v10_file.name)
assert vl.cipher_name == "AES256", "wrong cipher name set after rekey: %s" % vl.cipher_name
assert error_hit == False, "error decrypting migrated 1.0 file"
assert error_hit == False, "error decrypting migrated 1.0 file"
assert dec_data.strip() == "foo", "incorrect decryption of rekeyed/migrated file: %s" % dec_data
......@@ -20,6 +20,7 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
from six import text_type, binary_type
from six.moves import StringIO
from collections import Sequence, Set, Mapping
......@@ -28,6 +29,7 @@ from ansible.compat.tests.mock import patch
from ansible.parsing.yaml.loader import AnsibleLoader
class TestAnsibleLoaderBasic(unittest.TestCase):
def setUp(self):
......@@ -52,7 +54,7 @@ class TestAnsibleLoaderBasic(unittest.TestCase):
loader = AnsibleLoader(stream, 'myfile.yml')
data = loader.get_single_data()
self.assertEqual(data, u'Ansible')
self.assertIsInstance(data, unicode)
self.assertIsInstance(data, text_type)
self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
......@@ -63,7 +65,7 @@ class TestAnsibleLoaderBasic(unittest.TestCase):
loader = AnsibleLoader(stream, 'myfile.yml')
data = loader.get_single_data()
self.assertEqual(data, u'Cafè Eñyei')
self.assertIsInstance(data, unicode)
self.assertIsInstance(data, text_type)
self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
......@@ -76,8 +78,8 @@ class TestAnsibleLoaderBasic(unittest.TestCase):
data = loader.get_single_data()
self.assertEqual(data, {'webster': 'daniel', 'oed': 'oxford'})
self.assertEqual(len(data), 2)
self.assertIsInstance(data.keys()[0], unicode)
self.assertIsInstance(data.values()[0], unicode)
self.assertIsInstance(list(data.keys())[0], text_type)
self.assertIsInstance(list(data.values())[0], text_type)
# Beginning of the first key
self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
......@@ -94,7 +96,7 @@ class TestAnsibleLoaderBasic(unittest.TestCase):
data = loader.get_single_data()
self.assertEqual(data, [u'a', u'b'])
self.assertEqual(len(data), 2)
self.assertIsInstance(data[0], unicode)
self.assertIsInstance(data[0], text_type)
self.assertEqual(data.ansible_pos, ('myfile.yml', 2, 17))
......@@ -204,10 +206,10 @@ class TestAnsibleLoaderPlay(unittest.TestCase):
def walk(self, data):
# Make sure there's no str in the data
self.assertNotIsInstance(data, str)
self.assertNotIsInstance(data, binary_type)
# Descend into various container types
if isinstance(data, unicode):
if isinstance(data, text_type):
# strings are a sequence so we have to be explicit here
return
elif isinstance(data, (Sequence, Set)):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment