Commit 5b929a49 authored by Edward Lemur's avatar Edward Lemur Committed by Commit Bot

depot_tools: Clean up auth.py

Bug: 1001756
Change-Id: I495991c30f7f00de17e7c098e2d88ef7783aff3c
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/1865265
Auto-Submit: Edward Lesmes <ehmaldonado@chromium.org>
Reviewed-by: 's avatarVadim Shtayura <vadimsh@chromium.org>
Commit-Queue: Edward Lesmes <ehmaldonado@chromium.org>
parent a0aed87f
...@@ -11,21 +11,13 @@ import datetime ...@@ -11,21 +11,13 @@ import datetime
import functools import functools
import json import json
import logging import logging
import optparse
import os import os
import sys
import threading
import urllib
import urlparse
import subprocess2 import subprocess2
from third_party import httplib2 from third_party import httplib2
# depot_tools/.
DEPOT_TOOLS_DIR = os.path.dirname(os.path.abspath(__file__))
# This is what most GAE apps require for authentication. # This is what most GAE apps require for authentication.
OAUTH_SCOPE_EMAIL = 'https://www.googleapis.com/auth/userinfo.email' OAUTH_SCOPE_EMAIL = 'https://www.googleapis.com/auth/userinfo.email'
# Gerrit and Git on *.googlesource.com require this scope. # Gerrit and Git on *.googlesource.com require this scope.
...@@ -39,38 +31,22 @@ def datetime_now(): ...@@ -39,38 +31,22 @@ def datetime_now():
return datetime.datetime.utcnow() return datetime.datetime.utcnow()
# Authentication configuration extracted from command line options.
# See doc string for 'make_auth_config' for meaning of fields.
AuthConfig = collections.namedtuple('AuthConfig', [
'use_oauth2', # deprecated, will be always True
'save_cookies', # deprecated, will be removed
'use_local_webserver',
'webserver_port',
])
# OAuth access token with its expiration time (UTC datetime or None if unknown). # OAuth access token with its expiration time (UTC datetime or None if unknown).
class AccessToken(collections.namedtuple('AccessToken', [ class AccessToken(collections.namedtuple('AccessToken', [
'token', 'token',
'expires_at', 'expires_at',
])): ])):
def needs_refresh(self, now=None): def needs_refresh(self):
"""True if this AccessToken should be refreshed.""" """True if this AccessToken should be refreshed."""
if self.expires_at is not None: if self.expires_at is not None:
now = now or datetime_now()
# Allow 30s of clock skew between client and backend. # Allow 30s of clock skew between client and backend.
now += datetime.timedelta(seconds=30) return datetime_now() + datetime.timedelta(seconds=30) >= self.expires_at
return now >= self.expires_at
# Token without expiration time never expires. # Token without expiration time never expires.
return False return False
class AuthenticationError(Exception): class LoginRequiredError(Exception):
"""Raised on errors related to authentication."""
class LoginRequiredError(AuthenticationError):
"""Interaction with the user is required to authenticate.""" """Interaction with the user is required to authenticate."""
def __init__(self, scopes=OAUTH_SCOPE_EMAIL): def __init__(self, scopes=OAUTH_SCOPE_EMAIL):
...@@ -80,342 +56,50 @@ class LoginRequiredError(AuthenticationError): ...@@ -80,342 +56,50 @@ class LoginRequiredError(AuthenticationError):
super(LoginRequiredError, self).__init__(msg) super(LoginRequiredError, self).__init__(msg)
class LuciContextAuthError(Exception):
"""Raised on errors related to unsuccessful attempts to load LUCI_CONTEXT"""
def __init__(self, msg, exc=None):
if exc is None:
logging.error(msg)
else:
logging.exception(msg)
msg = '%s: %s' % (msg, exc)
super(LuciContextAuthError, self).__init__(msg)
def has_luci_context_local_auth(): def has_luci_context_local_auth():
"""Returns whether LUCI_CONTEXT should be used for ambient authentication. """Returns whether LUCI_CONTEXT should be used for ambient authentication."""
""" return bool(os.environ.get('LUCI_CONTEXT'))
try:
params = _get_luci_context_local_auth_params()
except LuciContextAuthError:
return False
if params is None:
return False
return bool(params.default_account_id)
# TODO(crbug.com/1001756): Remove. luci-auth uses local auth if available,
# making this unnecessary.
def get_luci_context_access_token(scopes=OAUTH_SCOPE_EMAIL):
"""Returns a valid AccessToken from the local LUCI context auth server.
Adapted from
https://chromium.googlesource.com/infra/luci/luci-py/+/master/client/libs/luci_context/luci_context.py
See the link above for more details.
Returns:
AccessToken if LUCI_CONTEXT is present and attempt to load it is successful.
None if LUCI_CONTEXT is absent.
Raises:
LuciContextAuthError if LUCI_CONTEXT is present, but there was a failure
obtaining its access token.
"""
params = _get_luci_context_local_auth_params()
if params is None:
return None
return _get_luci_context_access_token(
params, datetime.datetime.utcnow(), scopes)
_LuciContextLocalAuthParams = collections.namedtuple(
'_LuciContextLocalAuthParams', [
'default_account_id',
'secret',
'rpc_port',
])
def _cache_thread_safe(f):
"""Decorator caching result of nullary function in thread-safe way."""
lock = threading.Lock()
cache = []
@functools.wraps(f)
def caching_wrapper():
if not cache:
with lock:
if not cache:
cache.append(f())
return cache[0]
# Allow easy way to clear cache, particularly useful in tests.
caching_wrapper.clear_cache = lambda: cache.pop() if cache else None
return caching_wrapper
@_cache_thread_safe
def _get_luci_context_local_auth_params():
"""Returns local auth parameters if local auth is configured else None.
Raises LuciContextAuthError on unexpected failures.
"""
ctx_path = os.environ.get('LUCI_CONTEXT')
if not ctx_path:
return None
ctx_path = ctx_path.decode(sys.getfilesystemencoding())
try:
loaded = _load_luci_context(ctx_path)
except (OSError, IOError, ValueError) as e:
raise LuciContextAuthError('Failed to open, read or decode LUCI_CONTEXT', e)
try:
local_auth = loaded.get('local_auth')
except AttributeError as e:
raise LuciContextAuthError('LUCI_CONTEXT not in proper format', e)
if local_auth is None:
logging.debug('LUCI_CONTEXT configured w/o local auth')
return None
try:
return _LuciContextLocalAuthParams(
default_account_id=local_auth.get('default_account_id'),
secret=local_auth.get('secret'),
rpc_port=int(local_auth.get('rpc_port')))
except (AttributeError, ValueError) as e:
raise LuciContextAuthError('local_auth config malformed', e)
def _load_luci_context(ctx_path):
# Kept separate for test mocking.
with open(ctx_path) as f:
return json.load(f)
def _get_luci_context_access_token(params, now, scopes=OAUTH_SCOPE_EMAIL):
# No account, local_auth shouldn't be used.
if not params.default_account_id:
return None
if not params.secret:
raise LuciContextAuthError('local_auth: no secret')
logging.debug('local_auth: requesting an access token for account "%s"',
params.default_account_id)
http = httplib2.Http()
host = '127.0.0.1:%d' % params.rpc_port
resp, content = http.request(
uri='http://%s/rpc/LuciLocalAuthService.GetOAuthToken' % host,
method='POST',
body=json.dumps({
'account_id': params.default_account_id,
'scopes': scopes.split(' '),
'secret': params.secret,
}),
headers={'Content-Type': 'application/json'})
if resp.status != 200:
raise LuciContextAuthError(
'local_auth: Failed to grab access token from '
'LUCI context server with status %d: %r' % (resp.status, content))
try:
token = json.loads(content)
error_code = token.get('error_code')
error_message = token.get('error_message')
access_token = token.get('access_token')
expiry = token.get('expiry')
except (AttributeError, ValueError) as e:
raise LuciContextAuthError('Unexpected access token response format', e)
if error_code:
raise LuciContextAuthError(
'Error %d in retrieving access token: %s', error_code, error_message)
if not access_token:
raise LuciContextAuthError(
'No access token returned from LUCI context server')
expiry_dt = None
if expiry:
try:
expiry_dt = datetime.datetime.utcfromtimestamp(expiry)
logging.debug(
'local_auth: got an access token for '
'account "%s" that expires in %d sec',
params.default_account_id, (expiry_dt - now).total_seconds())
except (TypeError, ValueError) as e:
raise LuciContextAuthError('Invalid expiry in returned token', e)
else:
logging.debug(
'local auth: got an access token for account "%s" that does not expire',
params.default_account_id)
access_token = AccessToken(access_token, expiry_dt)
if access_token.needs_refresh(now=now):
raise LuciContextAuthError('Received access token is already expired')
return access_token
def make_auth_config(
use_oauth2=None,
save_cookies=None,
use_local_webserver=None,
webserver_port=None):
"""Returns new instance of AuthConfig.
If some config option is None, it will be set to a reasonable default value.
This function also acts as an authoritative place for default values of
corresponding command line options.
"""
default = lambda val, d: val if val is not None else d
return AuthConfig(
default(use_oauth2, True),
default(save_cookies, True),
default(use_local_webserver, not _is_headless()),
default(webserver_port, 8090))
def add_auth_options(parser, default_config=None):
"""Appends OAuth related options to OptionParser."""
default_config = default_config or make_auth_config()
parser.auth_group = optparse.OptionGroup(parser, 'Auth options')
parser.add_option_group(parser.auth_group)
# OAuth2 vs password switch.
auth_default = 'use OAuth2' if default_config.use_oauth2 else 'use password'
parser.auth_group.add_option(
'--oauth2',
action='store_true',
dest='use_oauth2',
default=default_config.use_oauth2,
help='Use OAuth 2.0 instead of a password. [default: %s]' % auth_default)
parser.auth_group.add_option(
'--no-oauth2',
action='store_false',
dest='use_oauth2',
default=default_config.use_oauth2,
help='Use password instead of OAuth 2.0. [default: %s]' % auth_default)
# Password related options, deprecated.
parser.auth_group.add_option(
'--no-cookies',
action='store_false',
dest='save_cookies',
default=default_config.save_cookies,
help='Do not save authentication cookies to local disk.')
# OAuth2 related options.
# TODO(crbug.com/1001756): Remove. No longer supported.
parser.auth_group.add_option(
'--auth-no-local-webserver',
action='store_false',
dest='use_local_webserver',
default=default_config.use_local_webserver,
help='DEPRECATED. Do not use')
parser.auth_group.add_option(
'--auth-host-port',
type=int,
default=default_config.webserver_port,
help='DEPRECATED. Do not use')
parser.auth_group.add_option(
'--auth-refresh-token-json',
help='DEPRECATED. Do not use')
def extract_auth_config_from_options(options):
"""Given OptionParser parsed options, extracts AuthConfig from it.
OptionParser should be populated with auth options by 'add_auth_options'.
"""
return make_auth_config(
use_oauth2=options.use_oauth2,
save_cookies=False if options.use_oauth2 else options.save_cookies,
use_local_webserver=options.use_local_webserver,
webserver_port=options.auth_host_port)
def auth_config_to_command_options(auth_config):
"""AuthConfig -> list of strings with command line options.
Omits options that are set to default values.
"""
if not auth_config:
return []
defaults = make_auth_config()
opts = []
if auth_config.use_oauth2 != defaults.use_oauth2:
opts.append('--oauth2' if auth_config.use_oauth2 else '--no-oauth2')
if auth_config.save_cookies != auth_config.save_cookies:
if not auth_config.save_cookies:
opts.append('--no-cookies')
if auth_config.use_local_webserver != defaults.use_local_webserver:
if not auth_config.use_local_webserver:
opts.append('--auth-no-local-webserver')
if auth_config.webserver_port != defaults.webserver_port:
opts.extend(['--auth-host-port', str(auth_config.webserver_port)])
return opts
def get_authenticator(config, scopes=OAUTH_SCOPE_EMAIL):
"""Returns Authenticator instance to access given host.
Args:
config: AuthConfig instance.
scopes: space separated oauth scopes. Defaults to OAUTH_SCOPE_EMAIL.
Returns:
Authenticator object.
"""
return Authenticator(config, scopes)
class Authenticator(object): class Authenticator(object):
"""Object that knows how to refresh access tokens when needed. """Object that knows how to refresh access tokens when needed.
Args: Args:
config: AuthConfig object that holds authentication configuration. scopes: space separated oauth scopes. Defaults to OAUTH_SCOPE_EMAIL.
""" """
def __init__(self, config, scopes): def __init__(self, scopes=OAUTH_SCOPE_EMAIL):
assert isinstance(config, AuthConfig)
assert config.use_oauth2
self._access_token = None self._access_token = None
self._config = config
self._lock = threading.Lock()
self._scopes = scopes self._scopes = scopes
logging.debug('Using auth config %r', config)
def has_cached_credentials(self): def has_cached_credentials(self):
"""Returns True if credentials can be obtained. """Returns True if credentials can be obtained.
If returns False, get_access_token() later will probably ask for interactive If returns False, get_access_token() later will probably ask for interactive
login by raising LoginRequiredError, unless local auth is configured. login by raising LoginRequiredError.
If returns True, get_access_token() won't ask for interactive login. If returns True, get_access_token() won't ask for interactive login.
""" """
with self._lock: return bool(self._get_luci_auth_token())
return bool(self._get_luci_auth_token())
def get_access_token(self, force_refresh=False, allow_user_interaction=False, def get_access_token(self):
use_local_auth=True):
"""Returns AccessToken, refreshing it if necessary. """Returns AccessToken, refreshing it if necessary.
Args:
TODO(crbug.com/1001756): Remove.
force_refresh: Ignored, luci-auth doesn't support force-refreshing tokens.
allow_user_interaction: Ignored. allow_user_interaction is always False.
use_local_auth: Ignored. luci-auth already covers local_auth.
Raises: Raises:
AuthenticationError on error or if authentication flow was interrupted. LoginRequiredError if user interaction is required.
LoginRequiredError if user interaction is required, but
allow_user_interaction is False.
""" """
with self._lock: if self._access_token and not self._access_token.needs_refresh():
if self._access_token and not self._access_token.needs_refresh(): return self._access_token
return self._access_token
# Token expired or missing. Maybe some other process already updated it, # Token expired or missing. Maybe some other process already updated it,
# reload from the cache. # reload from the cache.
self._access_token = self._get_luci_auth_token() self._access_token = self._get_luci_auth_token()
if self._access_token and not self._access_token.needs_refresh(): if self._access_token and not self._access_token.needs_refresh():
return self._access_token return self._access_token
# Nope, still expired, need to run the refresh flow. # Nope, still expired. Needs user interaction.
logging.error('Failed to create access token') logging.error('Failed to create access token')
raise LoginRequiredError(self._scopes) raise LoginRequiredError(self._scopes)
def authorize(self, http): def authorize(self, http):
"""Monkey patches authentication logic of httplib2.Http instance. """Monkey patches authentication logic of httplib2.Http instance.
...@@ -470,11 +154,3 @@ class Authenticator(object): ...@@ -470,11 +154,3 @@ class Authenticator(object):
datetime.datetime.utcfromtimestamp(token_info['expiry'])) datetime.datetime.utcfromtimestamp(token_info['expiry']))
except subprocess2.CalledProcessError: except subprocess2.CalledProcessError:
return None return None
## Private functions.
def _is_headless():
"""True if machine doesn't seem to have a display."""
return sys.platform == 'linux2' and not os.environ.get('DISPLAY')
...@@ -351,17 +351,11 @@ class LuciContextAuthenticator(Authenticator): ...@@ -351,17 +351,11 @@ class LuciContextAuthenticator(Authenticator):
return auth.has_luci_context_local_auth() return auth.has_luci_context_local_auth()
def __init__(self): def __init__(self):
self._access_token = None self._authenticator = auth.Authenticator(
self._ensure_fresh() ' '.join([auth.OAUTH_SCOPE_EMAIL, auth.OAUTH_SCOPE_GERRIT]))
def _ensure_fresh(self):
if not self._access_token or self._access_token.needs_refresh():
self._access_token = auth.get_luci_context_access_token(
scopes=' '.join([auth.OAUTH_SCOPE_EMAIL, auth.OAUTH_SCOPE_GERRIT]))
def get_auth_header(self, _host): def get_auth_header(self, _host):
self._ensure_fresh() return 'Bearer %s' % self._authenticator.get_access_token().token
return 'Bearer %s' % self._access_token.token
def CreateHttpConn(host, path, reqtype='GET', headers=None, body=None): def CreateHttpConn(host, path, reqtype='GET', headers=None, body=None):
......
...@@ -443,11 +443,10 @@ def _parse_bucket(raw_bucket): ...@@ -443,11 +443,10 @@ def _parse_bucket(raw_bucket):
return project, bucket return project, bucket
def _trigger_try_jobs(auth_config, changelist, buckets, options, patchset): def _trigger_try_jobs(changelist, buckets, options, patchset):
"""Sends a request to Buildbucket to trigger tryjobs for a changelist. """Sends a request to Buildbucket to trigger tryjobs for a changelist.
Args: Args:
auth_config: AuthConfig for Buildbucket.
changelist: Changelist that the tryjobs are associated with. changelist: Changelist that the tryjobs are associated with.
buckets: A nested dict mapping bucket names to builders to tests. buckets: A nested dict mapping bucket names to builders to tests.
options: Command-line options. options: Command-line options.
...@@ -466,7 +465,7 @@ def _trigger_try_jobs(auth_config, changelist, buckets, options, patchset): ...@@ -466,7 +465,7 @@ def _trigger_try_jobs(auth_config, changelist, buckets, options, patchset):
if not requests: if not requests:
return return
http = auth.get_authenticator(auth_config).authorize(httplib2.Http()) http = auth.Authenticator().authorize(httplib2.Http())
http.force_exception_to_status_code = True http.force_exception_to_status_code = True
batch_request = {'requests': requests} batch_request = {'requests': requests}
...@@ -527,8 +526,7 @@ def _make_try_job_schedule_requests(changelist, buckets, options, patchset): ...@@ -527,8 +526,7 @@ def _make_try_job_schedule_requests(changelist, buckets, options, patchset):
return requests return requests
def fetch_try_jobs(auth_config, changelist, buildbucket_host, def fetch_try_jobs(changelist, buildbucket_host, patchset=None):
patchset=None):
"""Fetches tryjobs from buildbucket. """Fetches tryjobs from buildbucket.
Returns list of buildbucket.v2.Build with the try jobs for the changelist. Returns list of buildbucket.v2.Build with the try jobs for the changelist.
...@@ -541,7 +539,7 @@ def fetch_try_jobs(auth_config, changelist, buildbucket_host, ...@@ -541,7 +539,7 @@ def fetch_try_jobs(auth_config, changelist, buildbucket_host,
'fields': ','.join('builds.*.' + field for field in fields), 'fields': ','.join('builds.*.' + field for field in fields),
} }
authenticator = auth.get_authenticator(auth_config) authenticator = auth.Authenticator()
if authenticator.has_cached_credentials(): if authenticator.has_cached_credentials():
http = authenticator.authorize(httplib2.Http()) http = authenticator.authorize(httplib2.Http())
else: else:
...@@ -554,13 +552,11 @@ def fetch_try_jobs(auth_config, changelist, buildbucket_host, ...@@ -554,13 +552,11 @@ def fetch_try_jobs(auth_config, changelist, buildbucket_host,
response = _call_buildbucket(http, buildbucket_host, 'SearchBuilds', request) response = _call_buildbucket(http, buildbucket_host, 'SearchBuilds', request)
return response.get('builds', []) return response.get('builds', [])
def _fetch_latest_builds( def _fetch_latest_builds(changelist, buildbucket_host, latest_patchset=None):
auth_config, changelist, buildbucket_host, latest_patchset=None):
"""Fetches builds from the latest patchset that has builds (within """Fetches builds from the latest patchset that has builds (within
the last few patchsets). the last few patchsets).
Args: Args:
auth_config (auth.AuthConfig): Auth info for Buildbucket
changelist (Changelist): The CL to fetch builds for changelist (Changelist): The CL to fetch builds for
buildbucket_host (str): Buildbucket host, e.g. "cr-buildbucket.appspot.com" buildbucket_host (str): Buildbucket host, e.g. "cr-buildbucket.appspot.com"
lastest_patchset(int|NoneType): the patchset to start fetching builds from. lastest_patchset(int|NoneType): the patchset to start fetching builds from.
...@@ -581,8 +577,7 @@ def _fetch_latest_builds( ...@@ -581,8 +577,7 @@ def _fetch_latest_builds(
min_ps = max(1, ps - 5) min_ps = max(1, ps - 5)
while ps >= min_ps: while ps >= min_ps:
builds = fetch_try_jobs( builds = fetch_try_jobs(changelist, buildbucket_host, patchset=ps)
auth_config, changelist, buildbucket_host, patchset=ps)
if len(builds): if len(builds):
return builds, ps return builds, ps
ps -= 1 ps -= 1
...@@ -4439,7 +4434,6 @@ def CMDupload(parser, args): ...@@ -4439,7 +4434,6 @@ def CMDupload(parser, args):
'fixed (pre-populates "Fixed:" tag). Same format as ' 'fixed (pre-populates "Fixed:" tag). Same format as '
'-b option / "Bug:" tag. If fixing several issues, ' '-b option / "Bug:" tag. If fixing several issues, '
'separate with commas.') 'separate with commas.')
auth.add_auth_options(parser)
orig_args = args orig_args = args
_add_codereview_select_options(parser) _add_codereview_select_options(parser)
...@@ -4487,15 +4481,13 @@ def CMDupload(parser, args): ...@@ -4487,15 +4481,13 @@ def CMDupload(parser, args):
if ret != 0: if ret != 0:
print('Upload failed, so --retry-failed has no effect.') print('Upload failed, so --retry-failed has no effect.')
return ret return ret
auth_config = auth.extract_auth_config_from_options(options)
builds, _ = _fetch_latest_builds( builds, _ = _fetch_latest_builds(
auth_config, cl, options.buildbucket_host, cl, options.buildbucket_host, latest_patchset=patchset)
latest_patchset=patchset)
buckets = _filter_failed_for_retry(builds) buckets = _filter_failed_for_retry(builds)
if len(buckets) == 0: if len(buckets) == 0:
print('No failed tryjobs, so --retry-failed has no effect.') print('No failed tryjobs, so --retry-failed has no effect.')
return ret return ret
_trigger_try_jobs(auth_config, cl, buckets, options, patchset + 1) _trigger_try_jobs(cl, buckets, options, patchset + 1)
return ret return ret
...@@ -4745,10 +4737,8 @@ def CMDtry(parser, args): ...@@ -4745,10 +4737,8 @@ def CMDtry(parser, args):
'-R', '--retry-failed', action='store_true', default=False, '-R', '--retry-failed', action='store_true', default=False,
help='Retry failed jobs from the latest set of tryjobs. ' help='Retry failed jobs from the latest set of tryjobs. '
'Not allowed with --bucket and --bot options.') 'Not allowed with --bucket and --bot options.')
auth.add_auth_options(parser)
_add_codereview_issue_select_options(parser) _add_codereview_issue_select_options(parser)
options, args = parser.parse_args(args) options, args = parser.parse_args(args)
auth_config = auth.extract_auth_config_from_options(options)
# Make sure that all properties are prop=value pairs. # Make sure that all properties are prop=value pairs.
bad_params = [x for x in options.properties if '=' not in x] bad_params = [x for x in options.properties if '=' not in x]
...@@ -4775,8 +4765,7 @@ def CMDtry(parser, args): ...@@ -4775,8 +4765,7 @@ def CMDtry(parser, args):
'-B, -b, --bucket, or --bot.', file=sys.stderr) '-B, -b, --bucket, or --bot.', file=sys.stderr)
return 1 return 1
print('Searching for failed tryjobs...') print('Searching for failed tryjobs...')
builds, patchset = _fetch_latest_builds( builds, patchset = _fetch_latest_builds(cl, options.buildbucket_host)
auth_config, cl, options.buildbucket_host)
if options.verbose: if options.verbose:
print('Got %d builds in patchset #%d' % (len(builds), patchset)) print('Got %d builds in patchset #%d' % (len(builds), patchset))
buckets = _filter_failed_for_retry(builds) buckets = _filter_failed_for_retry(builds)
...@@ -4812,7 +4801,7 @@ def CMDtry(parser, args): ...@@ -4812,7 +4801,7 @@ def CMDtry(parser, args):
patchset = cl.GetMostRecentPatchset() patchset = cl.GetMostRecentPatchset()
try: try:
_trigger_try_jobs(auth_config, cl, buckets, options, patchset) _trigger_try_jobs(cl, buckets, options, patchset)
except BuildbucketResponseException as ex: except BuildbucketResponseException as ex:
print('ERROR: %s' % ex) print('ERROR: %s' % ex)
return 1 return 1
...@@ -4837,13 +4826,11 @@ def CMDtry_results(parser, args): ...@@ -4837,13 +4826,11 @@ def CMDtry_results(parser, args):
'--json', help=('Path of JSON output file to write tryjob results to,' '--json', help=('Path of JSON output file to write tryjob results to,'
'or "-" for stdout.')) 'or "-" for stdout.'))
parser.add_option_group(group) parser.add_option_group(group)
auth.add_auth_options(parser)
_add_codereview_issue_select_options(parser) _add_codereview_issue_select_options(parser)
options, args = parser.parse_args(args) options, args = parser.parse_args(args)
if args: if args:
parser.error('Unrecognized args: %s' % ' '.join(args)) parser.error('Unrecognized args: %s' % ' '.join(args))
auth_config = auth.extract_auth_config_from_options(options)
cl = Changelist(issue=options.issue) cl = Changelist(issue=options.issue)
if not cl.GetIssue(): if not cl.GetIssue():
parser.error('Need to upload first.') parser.error('Need to upload first.')
...@@ -4858,7 +4845,7 @@ def CMDtry_results(parser, args): ...@@ -4858,7 +4845,7 @@ def CMDtry_results(parser, args):
cl.GetIssue()) cl.GetIssue())
try: try:
jobs = fetch_try_jobs(auth_config, cl, options.buildbucket_host, patchset) jobs = fetch_try_jobs(cl, options.buildbucket_host, patchset)
except BuildbucketResponseException as ex: except BuildbucketResponseException as ex:
print('Buildbucket error: %s' % ex) print('Buildbucket error: %s' % ex)
return 1 return 1
...@@ -5470,7 +5457,7 @@ def main(argv): ...@@ -5470,7 +5457,7 @@ def main(argv):
dispatcher = subcommand.CommandDispatcher(__name__) dispatcher = subcommand.CommandDispatcher(__name__)
try: try:
return dispatcher.execute(OptionParser(), argv) return dispatcher.execute(OptionParser(), argv)
except auth.AuthenticationError as e: except auth.LoginRequiredError as e:
DieWithError(str(e)) DieWithError(str(e))
except urllib2.HTTPError as e: except urllib2.HTTPError as e:
if e.code != 500: if e.code != 500:
......
...@@ -292,12 +292,10 @@ class MyActivity(object): ...@@ -292,12 +292,10 @@ class MyActivity(object):
return ret return ret
def monorail_get_auth_http(self): def monorail_get_auth_http(self):
auth_config = auth.extract_auth_config_from_options(self.options)
authenticator = auth.get_authenticator(auth_config)
# Manually use a long timeout (10m); for some users who have a # Manually use a long timeout (10m); for some users who have a
# long history on the issue tracker, whatever the default timeout # long history on the issue tracker, whatever the default timeout
# is is reached. # is is reached.
return authenticator.authorize(httplib2.Http(timeout=600)) return auth.Authenticator().authorize(httplib2.Http(timeout=600))
def filter_modified_monorail_issue(self, issue): def filter_modified_monorail_issue(self, issue):
"""Precisely checks if an issue has been modified in the time range. """Precisely checks if an issue has been modified in the time range.
...@@ -809,7 +807,6 @@ def main(): ...@@ -809,7 +807,6 @@ def main():
'-j', '--json', action='store_true', '-j', '--json', action='store_true',
help='Output json data (overrides other format options)') help='Output json data (overrides other format options)')
parser.add_option_group(output_format_group) parser.add_option_group(output_format_group)
auth.add_auth_options(parser)
parser.add_option( parser.add_option(
'-v', '--verbose', '-v', '--verbose',
...@@ -925,8 +922,8 @@ def main(): ...@@ -925,8 +922,8 @@ def main():
my_activity.get_issues() my_activity.get_issues()
if not options.no_referenced_issues: if not options.no_referenced_issues:
my_activity.get_referenced_issues() my_activity.get_referenced_issues()
except auth.AuthenticationError as e: except auth.LoginRequiredError as e:
logging.error('auth.AuthenticationError: %s', e) logging.error('auth.LoginRequiredError: %s', e)
my_activity.show_progress('\n') my_activity.show_progress('\n')
......
...@@ -1417,9 +1417,8 @@ def CheckChangedLUCIConfigs(input_api, output_api): ...@@ -1417,9 +1417,8 @@ def CheckChangedLUCIConfigs(input_api, output_api):
# authentication # authentication
try: try:
authenticator = auth.get_authenticator(auth.make_auth_config()) acc_tkn = auth.Authenticator().get_access_token()
acc_tkn = authenticator.get_access_token() except auth.LoginRequiredError as e:
except auth.AuthenticationError as e:
return [output_api.PresubmitError( return [output_api.PresubmitError(
'Error in authenticating user.', long_text=str(e))] 'Error in authenticating user.', long_text=str(e))]
......
...@@ -34,42 +34,57 @@ class AuthenticatorTest(unittest.TestCase): ...@@ -34,42 +34,57 @@ class AuthenticatorTest(unittest.TestCase):
def testHasCachedCredentials_NotLoggedIn(self): def testHasCachedCredentials_NotLoggedIn(self):
subprocess2.check_call_out.side_effect = [ subprocess2.check_call_out.side_effect = [
subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout', 'stderr')] subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout', 'stderr')]
authenticator = auth.get_authenticator(auth.make_auth_config()) self.assertFalse(auth.Authenticator().has_cached_credentials())
self.assertFalse(authenticator.has_cached_credentials())
def testHasCachedCredentials_LoggedIn(self): def testHasCachedCredentials_LoggedIn(self):
subprocess2.check_call_out.return_value = ( subprocess2.check_call_out.return_value = (
json.dumps({'token': 'token', 'expiry': 12345678}), '') json.dumps({'token': 'token', 'expiry': 12345678}), '')
authenticator = auth.get_authenticator(auth.make_auth_config()) self.assertTrue(auth.Authenticator().has_cached_credentials())
self.assertTrue(authenticator.has_cached_credentials())
def testGetAccessToken_NotLoggedIn(self): def testGetAccessToken_NotLoggedIn(self):
subprocess2.check_call_out.side_effect = [ subprocess2.check_call_out.side_effect = [
subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout', 'stderr')] subprocess2.CalledProcessError(1, ['cmd'], 'cwd', 'stdout', 'stderr')]
authenticator = auth.get_authenticator(auth.make_auth_config()) self.assertRaises(
self.assertRaises(auth.LoginRequiredError, authenticator.get_access_token) auth.LoginRequiredError, auth.Authenticator().get_access_token)
def testGetAccessToken_CachedToken(self): def testGetAccessToken_CachedToken(self):
authenticator = auth.get_authenticator(auth.make_auth_config()) authenticator = auth.Authenticator()
authenticator._access_token = auth.AccessToken('token', None) authenticator._access_token = auth.AccessToken('token', None)
self.assertEqual( self.assertEqual(
auth.AccessToken('token', None), authenticator.get_access_token()) auth.AccessToken('token', None), authenticator.get_access_token())
self.assertEqual(0, len(subprocess2.check_call_out.mock_calls))
def testGetAccesstoken_LoggedIn(self): def testGetAccesstoken_LoggedIn(self):
expiry = calendar.timegm(VALID_EXPIRY.timetuple()) expiry = calendar.timegm(VALID_EXPIRY.timetuple())
subprocess2.check_call_out.return_value = ( subprocess2.check_call_out.return_value = (
json.dumps({'token': 'token', 'expiry': expiry}), '') json.dumps({'token': 'token', 'expiry': expiry}), '')
authenticator = auth.get_authenticator(auth.make_auth_config())
self.assertEqual( self.assertEqual(
auth.AccessToken('token', VALID_EXPIRY), auth.AccessToken('token', VALID_EXPIRY),
authenticator.get_access_token()) auth.Authenticator().get_access_token())
subprocess2.check_call_out.assert_called_with(
['luci-auth',
'token',
'-scopes', auth.OAUTH_SCOPE_EMAIL,
'-json-output', '-'],
stdout=subprocess2.PIPE, stderr=subprocess2.PIPE)
def testGetAccessToken_DifferentScope(self):
expiry = calendar.timegm(VALID_EXPIRY.timetuple())
subprocess2.check_call_out.return_value = (
json.dumps({'token': 'token', 'expiry': expiry}), '')
self.assertEqual(
auth.AccessToken('token', VALID_EXPIRY),
auth.Authenticator('custom scopes').get_access_token())
subprocess2.check_call_out.assert_called_with(
['luci-auth', 'token', '-scopes', 'custom scopes', '-json-output', '-'],
stdout=subprocess2.PIPE, stderr=subprocess2.PIPE)
def testAuthorize(self): def testAuthorize(self):
http = mock.Mock() http = mock.Mock()
http_request = http.request http_request = http.request
http_request.__name__ = '__name__' http_request.__name__ = '__name__'
authenticator = auth.get_authenticator(auth.make_auth_config()) authenticator = auth.Authenticator()
authenticator._access_token = auth.AccessToken('token', None) authenticator._access_token = auth.AccessToken('token', None)
authorized = authenticator.authorize(http) authorized = authenticator.authorize(http)
......
...@@ -637,7 +637,7 @@ class TestGitCl(TestCase): ...@@ -637,7 +637,7 @@ class TestGitCl(TestCase):
self._mocked_call('write_json', path, contents)) self._mocked_call('write_json', path, contents))
self.mock(git_cl.presubmit_support, 'DoPresubmitChecks', PresubmitMock) self.mock(git_cl.presubmit_support, 'DoPresubmitChecks', PresubmitMock)
self.mock(git_cl.watchlists, 'Watchlists', WatchlistsMock) self.mock(git_cl.watchlists, 'Watchlists', WatchlistsMock)
self.mock(git_cl.auth, 'get_authenticator', AuthenticatorMock) self.mock(git_cl.auth, 'Authenticator', AuthenticatorMock)
self.mock(git_cl.gerrit_util, 'GetChangeDetail', self.mock(git_cl.gerrit_util, 'GetChangeDetail',
lambda *args, **kwargs: self._mocked_call( lambda *args, **kwargs: self._mocked_call(
'GetChangeDetail', *args, **kwargs)) 'GetChangeDetail', *args, **kwargs))
...@@ -3062,7 +3062,7 @@ class CMDTestCaseBase(unittest.TestCase): ...@@ -3062,7 +3062,7 @@ class CMDTestCaseBase(unittest.TestCase):
return_value='https://chromium-review.googlesource.com').start() return_value='https://chromium-review.googlesource.com').start()
mock.patch('git_cl.Changelist.GetMostRecentPatchset', mock.patch('git_cl.Changelist.GetMostRecentPatchset',
return_value=7).start() return_value=7).start()
mock.patch('git_cl.auth.get_authenticator', mock.patch('git_cl.auth.Authenticator',
return_value=AuthenticatorMock()).start() return_value=AuthenticatorMock()).start()
mock.patch('git_cl.Changelist._GetChangeDetail', mock.patch('git_cl.Changelist._GetChangeDetail',
return_value=self._CHANGE_DETAIL).start() return_value=self._CHANGE_DETAIL).start()
...@@ -3382,14 +3382,14 @@ class CMDUploadTestCase(CMDTestCaseBase): ...@@ -3382,14 +3382,14 @@ class CMDUploadTestCase(CMDTestCaseBase):
self.assertEqual(0, git_cl.main(['upload', '--retry-failed'])) self.assertEqual(0, git_cl.main(['upload', '--retry-failed']))
self.assertEqual([ self.assertEqual([
mock.call(mock.ANY, mock.ANY, 'cr-buildbucket.appspot.com', patchset=7), mock.call(mock.ANY, 'cr-buildbucket.appspot.com', patchset=7),
mock.call(mock.ANY, mock.ANY, 'cr-buildbucket.appspot.com', patchset=6), mock.call(mock.ANY, 'cr-buildbucket.appspot.com', patchset=6),
], git_cl.fetch_try_jobs.mock_calls) ], git_cl.fetch_try_jobs.mock_calls)
expected_buckets = { expected_buckets = {
'chromium/try': {'bot_failure': [], 'bot_infra_failure': []}, 'chromium/try': {'bot_failure': [], 'bot_infra_failure': []},
} }
git_cl._trigger_try_jobs.assert_called_once_with( git_cl._trigger_try_jobs.assert_called_once_with(
mock.ANY, mock.ANY, expected_buckets, mock.ANY, 8) mock.ANY, expected_buckets, mock.ANY, 8)
class CMDFormatTestCase(TestCase): class CMDFormatTestCase(TestCase):
......
...@@ -1642,7 +1642,7 @@ class CannedChecksUnittest(PresubmitTestsBase): ...@@ -1642,7 +1642,7 @@ class CannedChecksUnittest(PresubmitTestsBase):
presubmit.OutputApi.PresubmitPromptWarning) presubmit.OutputApi.PresubmitPromptWarning)
@mock.patch('git_cl.Changelist') @mock.patch('git_cl.Changelist')
@mock.patch('auth.get_authenticator') @mock.patch('auth.Authenticator')
def testCannedCheckChangedLUCIConfigs(self, mockGetAuth, mockChangelist): def testCannedCheckChangedLUCIConfigs(self, mockGetAuth, mockChangelist):
affected_file1 = mock.MagicMock(presubmit.GitAffectedFile) affected_file1 = mock.MagicMock(presubmit.GitAffectedFile)
affected_file1.LocalPath.return_value = 'foo.cfg' affected_file1.LocalPath.return_value = 'foo.cfg'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment