Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose TLS and authentication options in mongodb_uri field #287

Closed
wants to merge 10 commits into from
49 changes: 38 additions & 11 deletions mongo_orchestration/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@
import os
import ssl
import stat
import sys
import tempfile

if sys.version_info[0] == 2:
from urllib import quote_plus
else:
from urllib.parse import quote_plus

WORK_DIR = os.environ.get('MONGO_ORCHESTRATION_HOME', os.getcwd())
PID_FILE = os.path.join(WORK_DIR, 'server.pid')
LOG_FILE = os.path.join(WORK_DIR, 'server.log')
Expand Down Expand Up @@ -50,7 +56,11 @@
'ssl_certfile': DEFAULT_CLIENT_CERT,
'ssl_cert_reqs': ssl.CERT_NONE
}

SSL_TO_TLS_OPTION_MAPPINGS = {
'ssl': 'tls',
'ssl_certfile': 'tlsCertificateKeyFile',
'sslCAFile': 'tlsCAFile',
}

class BaseModel(object):
"""Base object for Server, ReplicaSet, and ShardedCluster."""
Expand Down Expand Up @@ -80,19 +90,36 @@ def _strip_auth(self, proc_params):
params.pop("clusterAuthMode", None)
return params

def mongodb_auth_uri(self, hosts):
"""Get a connection string with all info necessary to authenticate."""
parts = ['mongodb://']
def mongodb_uri(self, hosts, uri_opts):
"""Returns the connection string for the cluster"""
# Append TLS options
if self.ssl_params:
ssl_params = self.ssl_params.copy()
ssl_params.update(DEFAULT_SSL_OPTIONS)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This modifies the self.ssl_params property. Why do we need to use DEFAULT_SSL_OPTIONS here? Isn't that already the default for self.ssl_params?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I wasn't aware that this modifies the reference. sslParams for example doesn't contain ssl=true, it only contains the parameters necessary for the server (see sample orchestration files). DEFAULT_SSL_OPTIONS is applied to self.kwargs but not self.ssl_params. What would be the right way to apply these default options?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh thanks for explaining that. In this case I don't think we should add any of the self.ssl_params to the client URI since those are ssl params for the server itself. I think we should do this:

        if self.ssl_params:  # Server ssl params.
            ssl_params = DEFAULT_SSL_OPTIONS.copy()  # Client ssl params
            ....

The above approach would need to add the tlsInsecure=true option to match the ssl_cert_reqs=ssl.CERT_NONE in DEFAULT_SSL_OPTIONS.

Alternatively we can keep your current approach but make a copy of self.ssl_params to avoid modifying it:

        if self.ssl_params:
            ssl_params = self.ssl_params.copy()
            ssl_params.update(DEFAULT_SSL_OPTIONS)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update to copy the params. DEFAULT_SSL_OPTIONS by itself is not enough, as there are additional settings that come into play. The new SSL_TO_TLS_OPTION_MAPPINGS variable indicates which of the fields from the configuration's sslParams and DEFAULT_SSL_OPTIONS map to URI options for SSL. I can add additional tests to ensure we're not adding unknown or wrong URI options if you'd like.

Copy link
Collaborator

@ShaneHarvey ShaneHarvey Apr 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.ssl_params.copy() SGTM

# Rewrite ssl* option names to tls*
for sslKey, tlsKey in SSL_TO_TLS_OPTION_MAPPINGS.items():
sslValue = ssl_params.pop(sslKey)
if sslValue:
if isinstance(sslValue, bool):
sslValue = json.dumps(sslValue)
uri_opts.append(tlsKey + '=' + quote_plus(sslValue))

# Append Auth options
auth_opts = []
if self.login:
parts.append(self.login)
auth_opts.append(self.login)
if self.password:
parts.append(':' + self.password)
parts.append('@')
parts.append(hosts + '/')
if self.login:
parts.append('?authSource=' + self.auth_source)
auth_opts.append(':' + self.password)
auth_opts.append('@')
uri_opts.append('authSource=' + self.auth_source)
if self.x509_extra_user:
parts.append('&authMechanism=MONGODB-X509')
uri_opts.append('authMechanism=MONGODB-X509')

parts = ['mongodb://', ''.join(auth_opts), hosts, '/']

if len(uri_opts) > 0:
parts.append('?' + '&'.join(uri_opts))

return ''.join(parts)

def _get_server_version(self, client):
Expand Down
17 changes: 7 additions & 10 deletions mongo_orchestration/replica_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ def __init__(self, rs_params):
self.repl_id = rs_params.get('id', None) or str(uuid4())
self._version = rs_params.get('version')

self.sslParams = rs_params.get('sslParams', {})
self.ssl_params = rs_params.get('sslParams', {})
self.kwargs = {}
self.restart_required = self.login or self.auth_key
self.x509_extra_user = False

if self.sslParams:
if self.ssl_params:
self.kwargs.update(DEFAULT_SSL_OPTIONS)

members = rs_params.get('members', [])
Expand Down Expand Up @@ -229,17 +229,15 @@ def repl_update(self, config):
def info(self):
"""return information about replica set"""
hosts = ','.join(x['host'] for x in self.members())
mongodb_uri = 'mongodb://' + hosts + '/?replicaSet=' + self.repl_id
uri_opts = ['replicaSet=' + self.repl_id]
mongodb_uri = self.mongodb_uri(hosts, uri_opts)
result = {"id": self.repl_id,
"auth_key": self.auth_key,
"members": self.members(),
"mongodb_uri": mongodb_uri,
"orchestration": 'replica_sets'}
if self.login:
# Add replicaSet URI parameter.
uri = ('%s&replicaSet=%s'
% (self.mongodb_auth_uri(hosts), self.repl_id))
result['mongodb_auth_uri'] = uri
result['mongodb_auth_uri'] = result['mongodb_uri']
return result

def repl_member_add(self, params):
Expand Down Expand Up @@ -312,7 +310,7 @@ def member_create(self, params, member_id):
server_id = self._servers.create(
name='mongod',
procParams=proc_params,
sslParams=self.sslParams,
sslParams=self.ssl_params,
version=version,
server_id=server_id
)
Expand Down Expand Up @@ -359,8 +357,7 @@ def member_info(self, member_id):
'procInfo': server_info['procInfo'],
'statuses': server_info['statuses']}
if self.login:
result['mongodb_auth_uri'] = self.mongodb_auth_uri(
self._servers.hostname(server_id))
result['mongodb_auth_uri'] = result['mongodb_uri']
result['rsInfo'] = {}
if server_info['procInfo']['alive']:
# Can't call serverStatus on arbiter when running with auth enabled.
Expand Down
4 changes: 2 additions & 2 deletions mongo_orchestration/servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def info(self):
c = self.connection
server_info = c.server_info()
logger.debug("server_info: {server_info}".format(**locals()))
mongodb_uri = 'mongodb://' + self.hostname
mongodb_uri = self.mongodb_uri(self.hostname, [])
status_info = {"primary": c.is_primary, "mongos": c.is_mongos}
logger.debug("status_info: {status_info}".format(**locals()))
except (pymongo.errors.AutoReconnect, pymongo.errors.OperationFailure, pymongo.errors.ConnectionFailure):
Expand All @@ -311,7 +311,7 @@ def info(self):
"serverInfo": server_info, "procInfo": proc_info,
"orchestration": 'servers'}
if self.login:
result['mongodb_auth_uri'] = self.mongodb_auth_uri(self.hostname)
result['mongodb_auth_uri'] = mongodb_uri
logger.debug("return {result}".format(result=result))
return result

Expand Down
22 changes: 11 additions & 11 deletions mongo_orchestration/sharded_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def __init__(self, params):
self._shards = {}
self.tags = {}

self.sslParams = params.get('sslParams', {})
self.ssl_params = params.get('sslParams', {})
self.kwargs = {}
self.restart_required = self.login or self.auth_key
self.x509_extra_user = False

if self.sslParams:
if self.ssl_params:
self.kwargs.update(DEFAULT_SSL_OPTIONS)

self.enable_ipv6 = common.ipv6_enabled_sharded(params)
Expand Down Expand Up @@ -172,7 +172,7 @@ def only_x509(config):
def restart_with_auth(server_or_rs):
server_or_rs.x509_extra_user = self.x509_extra_user
server_or_rs.auth_source = self.auth_source
server_or_rs.ssl_params = self.sslParams
server_or_rs.ssl_params = self.ssl_params
server_or_rs.login = self.login
server_or_rs.password = self.password
server_or_rs.auth_key = self.auth_key
Expand Down Expand Up @@ -214,7 +214,7 @@ def __init_configrs(self, rs_cfg):
member['procParams']['configsvr'] = True
if self.enable_ipv6:
common.enable_ipv6_single(member['procParams'])
rs_cfg['sslParams'] = self.sslParams
rs_cfg['sslParams'] = self.ssl_params
self._configsvrs.append(ReplicaSets().create(rs_cfg))

def __init_configsvrs(self, params):
Expand All @@ -229,7 +229,7 @@ def __init_configsvrs(self, params):
if self.enable_ipv6:
common.enable_ipv6_single(cfg)
self._configsvrs.append(Servers().create(
'mongod', cfg, sslParams=self.sslParams, autostart=True,
'mongod', cfg, sslParams=self.ssl_params, autostart=True,
version=version, server_id=server_id))

def __len__(self):
Expand Down Expand Up @@ -285,7 +285,7 @@ def router_add(self, params):
params = self._strip_auth(params)

self._routers.append(Servers().create(
'mongos', params, sslParams=self.sslParams, autostart=True,
'mongos', params, sslParams=self.ssl_params, autostart=True,
version=version, server_id=server_id))
return {'id': self._routers[-1], 'hostname': Servers().hostname(self._routers[-1])}

Expand Down Expand Up @@ -362,7 +362,7 @@ def member_add(self, member_id=None, params=None):
rs_params = params.copy()
# Turn 'rs_id' -> 'id', to be consistent with 'server_id' below.
rs_params['id'] = rs_params.pop('rs_id', None)
rs_params.update({'sslParams': self.sslParams})
rs_params.update({'sslParams': self.ssl_params})

rs_params['version'] = params.pop('version', self._version)
rs_params['members'] = [
Expand All @@ -379,7 +379,7 @@ def member_add(self, member_id=None, params=None):
else:
# is single server
params.setdefault('procParams', {})['shardsvr'] = True
params.update({'autostart': True, 'sslParams': self.sslParams})
params.update({'autostart': True, 'sslParams': self.ssl_params})
params = params.copy()
params['procParams'] = self._strip_auth(
params.get('procParams', {}))
Expand Down Expand Up @@ -432,16 +432,16 @@ def reset(self):

def info(self):
"""return info about configuration"""
uri = ','.join(x['hostname'] for x in self.routers)
mongodb_uri = 'mongodb://' + uri
hosts = ','.join(x['hostname'] for x in self.routers)
mongodb_uri = self.mongodb_uri(hosts, [])
result = {'id': self.id,
'shards': self.members,
'configsvrs': self.configsvrs,
'routers': self.routers,
'mongodb_uri': mongodb_uri,
'orchestration': 'sharded_clusters'}
if self.login:
result['mongodb_auth_uri'] = self.mongodb_auth_uri(uri)
result['mongodb_auth_uri'] = mongodb_uri
return result

def cleanup(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
version_str = re.search('((\d+\.)+\d+)', info['version']).group(0)
SERVER_VERSION = tuple(map(int, version_str.split('.')))
# Do we have SSL support?
SSL_ENABLED = bool(info.get('OpenSSLVersion'))
SSL_ENABLED = bool(info.get('OpenSSLVersion')) or bool(info.get('openssl'))
finally:
Servers().cleanup()

Expand Down
6 changes: 6 additions & 0 deletions tests/test_replica_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,12 @@ def test_ssl(self):
self.repl.primary(), ssl_certfile=certificate('client.pem'),
ssl_cert_reqs=ssl.CERT_NONE))

# Check if mongodb_uri contains tls args
uri = self.server.info()['mongodb_uri']
self.assertIn(self.server.hostname, uri)
self.assertIn('tls=', uri)
self.assertIn('tlsCertificateKeyFile=', uri)

def test_mongodb_auth_uri(self):
if SERVER_VERSION < (2, 4):
raise SkipTest("Need to be able to set 'authenticationMechanisms' "
Expand Down
10 changes: 10 additions & 0 deletions tests/test_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import operator
import os
import socket
Expand All @@ -35,6 +36,9 @@
SkipTest, certificate, unittest, TEST_SUBJECT, SSLTestCase, SERVER_VERSION,
TEST_RELEASES)

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class ServerVersionTestCase(unittest.TestCase):
def _test_version_parse(self, version_str, expected_version):
Expand Down Expand Up @@ -459,6 +463,12 @@ def test_ssl(self):
self.server.hostname, ssl_certfile=certificate('client.pem'),
ssl_cert_reqs=ssl.CERT_NONE))

# Check if mongodb_uri contains tls args
uri = self.server.info()['mongodb_uri']
self.assertIn(self.server.hostname, uri)
self.assertIn('tls=', uri)
self.assertIn('tlsCertificateKeyFile=', uri)

def test_mongodb_auth_uri(self):
if SERVER_VERSION < (2, 4):
raise SkipTest("Need to be able to set 'authenticationMechanisms' "
Expand Down
6 changes: 6 additions & 0 deletions tests/test_sharded_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,12 @@ def test_ssl(self):
pymongo.MongoClient(host, ssl_certfile=certificate('client.pem'),
ssl_cert_reqs=ssl.CERT_NONE))

# Check if mongodb_uri contains tls args
uri = self.server.info()['mongodb_uri']
self.assertIn(self.server.hostname, uri)
self.assertIn('tls=', uri)
self.assertIn('tlsCertificateKeyFile=', uri)

def test_mongodb_auth_uri(self):
if SERVER_VERSION < (2, 4):
raise SkipTest("Need to be able to set 'authenticationMechanisms' "
Expand Down