Commit acc8e3eb authored by Mun Yong Jang's avatar Mun Yong Jang Committed by Commit Bot

[presubmit] Extend depot tools auth to use luci context

Bug: 509672
Change-Id: Ie3cb2fa1a2276f1fe658cdf7b9ffb657d03556e8
Reviewed-on: https://chromium-review.googlesource.com/754340
Commit-Queue: Mun Yong Jang <myjang@google.com>
Reviewed-by: 's avatarNodir Turakulov <nodir@chromium.org>
parent 7d9d9233
...@@ -16,6 +16,7 @@ import os ...@@ -16,6 +16,7 @@ import os
import socket import socket
import sys import sys
import threading import threading
import time
import urllib import urllib
import urlparse import urlparse
import webbrowser import webbrowser
...@@ -102,6 +103,119 @@ class LoginRequiredError(AuthenticationError): ...@@ -102,6 +103,119 @@ class LoginRequiredError(AuthenticationError):
super(LoginRequiredError, self).__init__(msg) super(LoginRequiredError, self).__init__(msg)
class LuciContextAuthError(Exception):
"""Raised on errors related to unsuccessful attempts to load LUCI_CONTEXT"""
def get_luci_context_access_token():
"""Returns a valid AccessToken from the local LUCI context auth server.
Adapted from
https://chromium.googlesource.com/infra/luci/luci-py/+/master/client/libs/luci_context/luci_context.py
See the link above for more details.
Returns:
AccessToken if LUCI_CONTEXT is present and attempt to load it is successful.
None if LUCI_CONTEXT is absent.
Raises:
LuciContextAuthError if the attempt to load LUCI_CONTEXT
and request its access token is unsuccessful.
"""
return _get_luci_context_access_token(os.environ, datetime.datetime.utcnow())
def _get_luci_context_access_token(env, now):
ctx_path = env.get('LUCI_CONTEXT')
if not ctx_path:
return None
ctx_path = ctx_path.decode(sys.getfilesystemencoding())
logging.debug('Loading LUCI_CONTEXT: %r', ctx_path)
def authErr(msg, *args):
error_msg = msg % args
ex = sys.exc_info()[1]
if not ex:
logging.error(error_msg)
raise LuciContextAuthError(error_msg)
logging.exception(error_msg)
raise LuciContextAuthError('%s: %s' % (error_msg, ex))
try:
loaded = _load_luci_context(ctx_path)
except (OSError, IOError, ValueError):
authErr('Failed to open, read or decode LUCI_CONTEXT')
try:
local_auth = loaded.get('local_auth')
except AttributeError:
authErr('LUCI_CONTEXT not in proper format')
# failed to grab local_auth from LUCI context
if not local_auth:
logging.debug('local_auth: no local auth found')
return None
try:
account_id = local_auth.get('default_account_id')
secret = local_auth.get('secret')
rpc_port = int(local_auth.get('rpc_port'))
except (AttributeError, ValueError):
authErr('local_auth: unexpected local auth format')
if not secret:
authErr('local_auth: no secret returned')
# if account_id not specified, LUCI_CONTEXT should not be picked up
if not account_id:
return None
logging.debug('local_auth: requesting an access token for account "%s"',
account_id)
http = httplib2.Http()
host = '127.0.0.1:%d' % rpc_port
resp, content = http.request(
uri='http://%s/rpc/LuciLocalAuthService.GetOAuthToken' % host,
method='POST',
body=json.dumps({
'account_id': account_id,
'scopes': OAUTH_SCOPES.split(' '),
'secret': secret,
}),
headers={'Content-Type': 'application/json'})
if resp.status != 200:
err = ('local_auth: Failed to grab access token from '
'LUCI context server with status %d: %r')
authErr(err, resp.status, content)
try:
token = json.loads(content)
error_code = token.get('error_code')
error_message = token.get('error_message')
access_token = token.get('access_token')
expiry = token.get('expiry')
except (AttributeError, ValueError):
authErr('local_auth: Unexpected access token response format')
if error_code:
authErr('local_auth: Error %d in retrieving access token: %s',
error_code, error_message)
if not access_token:
authErr('local_auth: No access token returned from LUCI context server')
expiry_dt = None
if expiry:
try:
expiry_dt = datetime.datetime.utcfromtimestamp(expiry)
except (TypeError, ValueError):
authErr('Invalid expiry in returned token')
logging.debug(
'local_auth: got an access token for account "%s" that expires in %d sec',
account_id, expiry - time.mktime(now.timetuple()))
access_token = AccessToken(access_token, expiry_dt)
if _needs_refresh(access_token, now=now):
authErr('local_auth: the returned access token needs to be refreshed')
return access_token
def _load_luci_context(ctx_path):
with open(ctx_path) as f:
return json.load(f)
def make_auth_config( def make_auth_config(
use_oauth2=None, use_oauth2=None,
save_cookies=None, save_cookies=None,
...@@ -219,6 +333,9 @@ def get_authenticator_for_host(hostname, config): ...@@ -219,6 +333,9 @@ def get_authenticator_for_host(hostname, config):
Returns: Returns:
Authenticator object. Authenticator object.
Raises:
AuthenticationError if hostname is invalid.
""" """
hostname = hostname.lower().rstrip('/') hostname = hostname.lower().rstrip('/')
# Append some scheme, otherwise urlparse puts hostname into parsed.path. # Append some scheme, otherwise urlparse puts hostname into parsed.path.
...@@ -303,23 +420,43 @@ class Authenticator(object): ...@@ -303,23 +420,43 @@ class Authenticator(object):
with self._lock: with self._lock:
return bool(self._get_cached_credentials()) return bool(self._get_cached_credentials())
def get_access_token(self, force_refresh=False, allow_user_interaction=False): def get_access_token(self, force_refresh=False, allow_user_interaction=False,
use_local_auth=True):
"""Returns AccessToken, refreshing it if necessary. """Returns AccessToken, refreshing it if necessary.
Args: Args:
force_refresh: forcefully refresh access token even if it is not expired. force_refresh: forcefully refresh access token even if it is not expired.
allow_user_interaction: True to enable blocking for user input if needed. allow_user_interaction: True to enable blocking for user input if needed.
use_local_auth: default to local auth if needed.
Raises: Raises:
AuthenticationError on error or if authentication flow was interrupted. AuthenticationError on error or if authentication flow was interrupted.
LoginRequiredError if user interaction is required, but LoginRequiredError if user interaction is required, but
allow_user_interaction is False. allow_user_interaction is False.
""" """
def get_loc_auth_tkn():
exi = sys.exc_info()
if not use_local_auth:
logging.error('Failed to create access token')
raise
try:
self._access_token = get_luci_context_access_token()
if not self._access_token:
logging.error('Failed to create access token')
raise
return self._access_token
except LuciContextAuthError:
logging.exception('Failed to use local auth')
raise exi[0], exi[1], exi[2]
with self._lock: with self._lock:
if force_refresh: if force_refresh:
logging.debug('Forcing access token refresh') logging.debug('Forcing access token refresh')
try:
self._access_token = self._create_access_token(allow_user_interaction) self._access_token = self._create_access_token(allow_user_interaction)
return self._access_token return self._access_token
except LoginRequiredError:
return get_loc_auth_tkn()
# Load from on-disk cache on a first access. # Load from on-disk cache on a first access.
if not self._access_token: if not self._access_token:
...@@ -331,7 +468,11 @@ class Authenticator(object): ...@@ -331,7 +468,11 @@ class Authenticator(object):
self._access_token = self._load_access_token() self._access_token = self._load_access_token()
# Nope, still expired, need to run the refresh flow. # Nope, still expired, need to run the refresh flow.
if not self._access_token or _needs_refresh(self._access_token): if not self._access_token or _needs_refresh(self._access_token):
self._access_token = self._create_access_token(allow_user_interaction) try:
self._access_token = self._create_access_token(
allow_user_interaction)
except LoginRequiredError:
get_loc_auth_tkn()
return self._access_token return self._access_token
...@@ -548,11 +689,12 @@ def _read_refresh_token_json(path): ...@@ -548,11 +689,12 @@ def _read_refresh_token_json(path):
'Failed to read refresh token from %s: missing key %s' % (path, e)) 'Failed to read refresh token from %s: missing key %s' % (path, e))
def _needs_refresh(access_token): def _needs_refresh(access_token, now=None):
"""True if AccessToken should be refreshed.""" """True if AccessToken should be refreshed."""
if access_token.expires_at is not None: if access_token.expires_at is not None:
now = now or datetime.datetime.utcnow()
# Allow 5 min of clock skew between client and backend. # Allow 5 min of clock skew between client and backend.
now = datetime.datetime.utcnow() + datetime.timedelta(seconds=300) now += datetime.timedelta(seconds=300)
return now >= access_token.expires_at return now >= access_token.expires_at
# Token without expiration time never expires. # Token without expiration time never expires.
return False return False
......
...@@ -71,7 +71,7 @@ def CheckChangedConfigs(input_api, output_api): ...@@ -71,7 +71,7 @@ def CheckChangedConfigs(input_api, output_api):
try: try:
authenticator = auth.get_authenticator_for_host( authenticator = auth.get_authenticator_for_host(
LUCI_CONFIG_HOST_NAME, auth.make_auth_config()) LUCI_CONFIG_HOST_NAME, auth.make_auth_config())
acc_tkn = authenticator.get_access_token(allow_user_interaction=True).token acc_tkn = authenticator.get_access_token()
except auth.AuthenticationError as e: except auth.AuthenticationError as e:
return [output_api.PresubmitError( return [output_api.PresubmitError(
'Error in authenticating user.', long_text=str(e))] 'Error in authenticating user.', long_text=str(e))]
...@@ -80,7 +80,7 @@ def CheckChangedConfigs(input_api, output_api): ...@@ -80,7 +80,7 @@ def CheckChangedConfigs(input_api, output_api):
api_url = ('https://%s/_ah/api/config/v1/%s' api_url = ('https://%s/_ah/api/config/v1/%s'
% (LUCI_CONFIG_HOST_NAME, endpoint)) % (LUCI_CONFIG_HOST_NAME, endpoint))
req = urllib2.Request(api_url) req = urllib2.Request(api_url)
req.add_header('Authorization', 'Bearer %s' % acc_tkn) req.add_header('Authorization', 'Bearer %s' % acc_tkn.token)
if body is not None: if body is not None:
req.add_header('Content-Type', 'application/json') req.add_header('Content-Type', 'application/json')
req.add_data(json.dumps(body)) req.add_data(json.dumps(body))
......
#!/usr/bin/env python
# Copyright (c) 2017 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Unit Tests for auth.py"""
import __builtin__
import datetime
import json
import logging
import os
import unittest
import sys
import time
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from testing_support import auto_stub
from third_party import httplib2
from third_party import mock
import auth
class TestGetLuciContextAccessToken(auto_stub.TestCase):
mock_env = {'LUCI_CONTEXT': 'default/test/path'}
def _mock_local_auth(self, account_id, secret, rpc_port):
self.mock(auth, '_load_luci_context', mock.Mock())
auth._load_luci_context.return_value = {
'local_auth': {
'default_account_id': account_id,
'secret': secret,
'rpc_port': rpc_port,
}
}
def _mock_loc_server_resp(self, status, content):
mock_resp = mock.Mock()
mock_resp.status = status
self.mock(httplib2.Http, 'request', mock.Mock())
httplib2.Http.request.return_value = (mock_resp, content)
def test_correct_local_auth_format(self):
self._mock_local_auth('dead', 'beef', 10)
expiry_time = datetime.datetime.min + datetime.timedelta(minutes=60)
resp_content = {
'error_code': None,
'error_message': None,
'access_token': 'token',
'expiry': time.mktime(expiry_time.timetuple()),
}
self._mock_loc_server_resp(200, json.dumps(resp_content))
token = auth._get_luci_context_access_token(
self.mock_env, datetime.datetime.min)
self.assertEquals(token.token, 'token')
def test_incorrect_port_format(self):
self._mock_local_auth('foo', 'bar', 'bar')
with self.assertRaises(auth.LuciContextAuthError):
auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min)
def test_no_account_id(self):
self._mock_local_auth(None, 'bar', 10)
token = auth._get_luci_context_access_token(
self.mock_env, datetime.datetime.min)
self.assertIsNone(token)
def test_expired_token(self):
self._mock_local_auth('dead', 'beef', 10)
resp_content = {
'error_code': None,
'error_message': None,
'access_token': 'token',
'expiry': 1,
}
self._mock_loc_server_resp(200, json.dumps(resp_content))
with self.assertRaises(auth.LuciContextAuthError):
auth._get_luci_context_access_token(
self.mock_env, datetime.datetime.utcfromtimestamp(1))
def test_incorrect_expiry_format(self):
self._mock_local_auth('dead', 'beef', 10)
resp_content = {
'error_code': None,
'error_message': None,
'access_token': 'token',
'expiry': 'dead',
}
self._mock_loc_server_resp(200, json.dumps(resp_content))
with self.assertRaises(auth.LuciContextAuthError):
auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min)
def test_incorrect_response_content_format(self):
self._mock_local_auth('dead', 'beef', 10)
self._mock_loc_server_resp(200, '5')
with self.assertRaises(auth.LuciContextAuthError):
auth._get_luci_context_access_token(self.mock_env, datetime.datetime.min)
if __name__ == '__main__':
if '-v' in sys.argv:
logging.basicConfig(level=logging.DEBUG)
unittest.main()
...@@ -1974,8 +1974,7 @@ class CannedChecksUnittest(PresubmitTestsBase): ...@@ -1974,8 +1974,7 @@ class CannedChecksUnittest(PresubmitTestsBase):
token_mock = self.mox.CreateMock(auth.AccessToken) token_mock = self.mox.CreateMock(auth.AccessToken)
token_mock.token = 123 token_mock.token = 123
auth_mock = self.mox.CreateMock(auth.Authenticator) auth_mock = self.mox.CreateMock(auth.Authenticator)
auth_mock.get_access_token( auth_mock.get_access_token().AndReturn(token_mock)
allow_user_interaction=True).AndReturn(token_mock)
self.mox.StubOutWithMock(auth, 'get_authenticator_for_host') self.mox.StubOutWithMock(auth, 'get_authenticator_for_host')
auth.get_authenticator_for_host( auth.get_authenticator_for_host(
mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(auth_mock) mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(auth_mock)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment