Skip to content

Commit 79883fa

Browse files
committed
[jobs] keep track of a managed job's owner/creator
1 parent a2f3859 commit 79883fa

File tree

10 files changed

+183
-51
lines changed

10 files changed

+183
-51
lines changed

sky/cli.py

+49-15
Original file line numberDiff line numberDiff line change
@@ -1379,12 +1379,14 @@ def exec(cluster: Optional[str], cluster_option: Optional[str],
13791379
def _handle_jobs_queue_request(
13801380
request_id: str,
13811381
show_all: bool,
1382+
show_user: bool,
13821383
limit_num_jobs_to_show: bool = False,
13831384
is_called_by_user: bool = False) -> Tuple[Optional[int], str]:
13841385
"""Get the in-progress managed jobs.
13851386
13861387
Args:
13871388
show_all: Show all information of each job (e.g., region, price).
1389+
show_user: Show the user who submitted the job.
13881390
limit_num_jobs_to_show: If True, limit the number of jobs to show to
13891391
_NUM_MANAGED_JOBS_TO_SHOW_IN_STATUS, which is mainly used by
13901392
`sky status`.
@@ -1452,6 +1454,7 @@ def _handle_jobs_queue_request(
14521454
if limit_num_jobs_to_show else None)
14531455
msg = managed_jobs.format_job_table(managed_jobs_,
14541456
show_all=show_all,
1457+
show_user=show_user,
14551458
max_jobs=max_jobs_to_show)
14561459
return num_in_progress_jobs, msg
14571460

@@ -1561,7 +1564,9 @@ def _status_kubernetes(show_all: bool):
15611564
click.echo(f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
15621565
f'Managed jobs'
15631566
f'{colorama.Style.RESET_ALL}')
1564-
msg = managed_jobs.format_job_table(all_jobs, show_all=show_all)
1567+
msg = managed_jobs.format_job_table(all_jobs,
1568+
show_all=show_all,
1569+
show_user=False)
15651570
click.echo(msg)
15661571
if any(['sky-serve-controller' in c.cluster_name for c in all_clusters]):
15671572
# TODO: Parse serve controllers and show services separately.
@@ -1779,7 +1784,8 @@ def status(verbose: bool, refresh: bool, ip: bool, endpoints: bool,
17791784
show_managed_jobs = show_managed_jobs and not any([clusters, ip, endpoints])
17801785
if show_managed_jobs:
17811786
managed_jobs_queue_request_id = managed_jobs.queue(refresh=False,
1782-
skip_finished=True)
1787+
skip_finished=True,
1788+
all_users=all_users)
17831789
show_endpoints = endpoints or endpoint is not None
17841790
show_single_endpoint = endpoint is not None
17851791
show_services = show_services and not any([clusters, ip, endpoints])
@@ -1859,6 +1865,7 @@ def status(verbose: bool, refresh: bool, ip: bool, endpoints: bool,
18591865
num_in_progress_jobs, msg = _handle_jobs_queue_request(
18601866
managed_jobs_queue_request_id,
18611867
show_all=False,
1868+
show_user=False,
18621869
limit_num_jobs_to_show=not all,
18631870
is_called_by_user=False)
18641871
except KeyboardInterrupt:
@@ -2751,7 +2758,7 @@ def start(
27512758
def down(
27522759
clusters: List[str],
27532760
all: bool, # pylint: disable=redefined-builtin
2754-
all_users: bool, # pylint: disable=redefined-builtin
2761+
all_users: bool,
27552762
yes: bool,
27562763
purge: bool,
27572764
async_call: bool,
@@ -2812,7 +2819,9 @@ def _hint_or_raise_for_down_jobs_controller(controller_name: str,
28122819
with rich_utils.client_status(
28132820
'[bold cyan]Checking for in-progress managed jobs[/]'):
28142821
try:
2815-
request_id = managed_jobs.queue(refresh=False, skip_finished=True)
2822+
request_id = managed_jobs.queue(refresh=False,
2823+
skip_finished=True,
2824+
all_users=True)
28162825
managed_jobs_ = sdk.stream_and_get(request_id)
28172826
except exceptions.ClusterNotUpError as e:
28182827
if controller.value.connection_error_hint in str(e):
@@ -2836,7 +2845,9 @@ def _hint_or_raise_for_down_jobs_controller(controller_name: str,
28362845
'jobs (output of `sky jobs queue`) will be lost.')
28372846
click.echo(msg)
28382847
if managed_jobs_:
2839-
job_table = managed_jobs.format_job_table(managed_jobs_, show_all=False)
2848+
job_table = managed_jobs.format_job_table(managed_jobs_,
2849+
show_all=False,
2850+
show_user=True)
28402851
msg = controller.value.decline_down_for_dirty_controller_hint
28412852
# Add prefix to each line to align with the bullet point.
28422853
msg += '\n'.join(
@@ -3905,9 +3916,16 @@ def jobs_launch(
39053916
is_flag=True,
39063917
required=False,
39073918
help='Show only pending/running jobs\' information.')
3919+
@click.option('--all-users',
3920+
'-u',
3921+
default=False,
3922+
is_flag=True,
3923+
required=False,
3924+
help='Show jobs from all users.')
39083925
@usage_lib.entrypoint
39093926
# pylint: disable=redefined-builtin
3910-
def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool):
3927+
def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool,
3928+
all_users: bool):
39113929
"""Show statuses of managed jobs.
39123930
39133931
Each managed job can have one of the following statuses:
@@ -3964,9 +3982,10 @@ def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool):
39643982
click.secho('Fetching managed job statuses...', fg='cyan')
39653983
with rich_utils.client_status('[cyan]Checking managed jobs[/]'):
39663984
managed_jobs_request_id = managed_jobs.queue(
3967-
refresh=refresh, skip_finished=skip_finished)
3985+
refresh=refresh, skip_finished=skip_finished, all_users=all_users)
39683986
_, msg = _handle_jobs_queue_request(managed_jobs_request_id,
39693987
show_all=verbose,
3988+
show_user=all_users,
39703989
is_called_by_user=True)
39713990
if not skip_finished:
39723991
in_progress_only_hint = ''
@@ -3989,16 +4008,23 @@ def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool):
39894008
is_flag=True,
39904009
default=False,
39914010
required=False,
3992-
help='Cancel all managed jobs.')
4011+
help='Cancel all managed jobs for the current user.')
39934012
@click.option('--yes',
39944013
'-y',
39954014
is_flag=True,
39964015
default=False,
39974016
required=False,
39984017
help='Skip confirmation prompt.')
4018+
@click.option('--all-users',
4019+
'-u',
4020+
is_flag=True,
4021+
default=False,
4022+
required=False,
4023+
help='Cancel all managed jobs from all users.')
39994024
@usage_lib.entrypoint
40004025
# pylint: disable=redefined-builtin
4001-
def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
4026+
def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool,
4027+
all_users: bool):
40024028
"""Cancel managed jobs.
40034029
40044030
You can provide either a job name or a list of job IDs to be cancelled.
@@ -4015,25 +4041,33 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
40154041
$ sky jobs cancel 1 2 3
40164042
"""
40174043
job_id_str = ','.join(map(str, job_ids))
4018-
if sum([bool(job_ids), name is not None, all]) != 1:
4019-
argument_str = f'--job-ids {job_id_str}' if job_ids else ''
4020-
argument_str += f' --name {name}' if name is not None else ''
4021-
argument_str += ' --all' if all else ''
4044+
if sum([bool(job_ids), name is not None, all, all_users]) != 1:
4045+
arguments = []
4046+
arguments += [f'--job-ids {job_id_str}'] if job_ids else []
4047+
arguments += [f'--name {name}'] if name is not None else []
4048+
arguments += ['--all'] if all else []
4049+
arguments += ['--all-users'] if all_users else []
40224050
raise click.UsageError(
40234051
'Can only specify one of JOB_IDS or --name or --all. '
4024-
f'Provided {argument_str!r}.')
4052+
f'Provided {" ".join(arguments)!r}.')
40254053

40264054
if not yes:
40274055
job_identity_str = (f'managed jobs with IDs {job_id_str}'
40284056
if job_ids else repr(name))
40294057
if all:
40304058
job_identity_str = 'all managed jobs'
4059+
if all_users:
4060+
job_identity_str = 'all managed jobs FOR ALL USERS'
40314061
click.confirm(f'Cancelling {job_identity_str}. Proceed?',
40324062
default=True,
40334063
abort=True,
40344064
show_default=True)
40354065

4036-
sdk.stream_and_get(managed_jobs.cancel(job_ids=job_ids, name=name, all=all))
4066+
sdk.stream_and_get(
4067+
managed_jobs.cancel(job_ids=job_ids,
4068+
name=name,
4069+
all=all,
4070+
all_users=all_users))
40374071

40384072

40394073
@jobs.command('logs', cls=_DocumentedCodeCommand)

sky/jobs/client/sdk.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,16 @@ def launch(
8585
@usage_lib.entrypoint
8686
@server_common.check_server_healthy_or_start
8787
def queue(refresh: bool,
88-
skip_finished: bool = False) -> server_common.RequestId:
88+
skip_finished: bool = False,
89+
all_users: bool = False) -> server_common.RequestId:
8990
"""Gets statuses of managed jobs.
9091
9192
Please refer to sky.cli.jobs_queue for documentation.
9293
9394
Args:
9495
refresh: Whether to restart the jobs controller if it is stopped.
9596
skip_finished: Whether to skip finished jobs.
97+
all_users: Whether to show all users' jobs.
9698
9799
Returns:
98100
The request ID of the queue request.
@@ -126,6 +128,7 @@ def queue(refresh: bool,
126128
body = payloads.JobsQueueBody(
127129
refresh=refresh,
128130
skip_finished=skip_finished,
131+
all_users=all_users,
129132
)
130133
response = requests.post(
131134
f'{server_common.get_server_url()}/jobs/queue',
@@ -138,9 +141,10 @@ def queue(refresh: bool,
138141
@usage_lib.entrypoint
139142
@server_common.check_server_healthy_or_start
140143
def cancel(
141-
name: Optional[str] = None,
142-
job_ids: Optional[List[int]] = None,
143-
all: bool = False, # pylint: disable=redefined-builtin
144+
name: Optional[str] = None,
145+
job_ids: Optional[List[int]] = None,
146+
all: bool = False, # pylint: disable=redefined-builtin
147+
all_users: bool = False,
144148
) -> server_common.RequestId:
145149
"""Cancels managed jobs.
146150
@@ -150,6 +154,7 @@ def cancel(
150154
name: Name of the managed job to cancel.
151155
job_ids: IDs of the managed jobs to cancel.
152156
all: Whether to cancel all managed jobs for the current user.
157+
all_users: Whether to cancel all managed jobs from all users.
153158
154159
Returns:
155160
The request ID of the cancel request.
@@ -162,6 +167,7 @@ def cancel(
162167
name=name,
163168
job_ids=job_ids,
164169
all=all,
170+
all_users=all_users,
165171
)
166172
response = requests.post(
167173
f'{server_common.get_server_url()}/jobs/cancel',

sky/jobs/dashboard/dashboard.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import yaml
1717

1818
from sky import jobs as managed_jobs
19+
from sky.client import sdk
1920
from sky.jobs import constants as managed_job_constants
2021
from sky.utils import common_utils
2122
from sky.utils import controller_utils
@@ -134,14 +135,16 @@ def _extract_launch_history(log_content: str) -> str:
134135
def home():
135136
if not _is_running_on_jobs_controller():
136137
# Experimental: run on laptop (refresh is very slow).
137-
all_managed_jobs = managed_jobs.queue(refresh=True, skip_finished=False)
138+
request_id = managed_jobs.queue(refresh=True, skip_finished=False)
139+
all_managed_jobs = sdk.get(request_id)
138140
else:
139141
job_table = managed_jobs.dump_managed_job_queue()
140142
all_managed_jobs = managed_jobs.load_managed_job_queue(job_table)
141143

142144
timestamp = datetime.datetime.now(datetime.timezone.utc)
143145
rows = managed_jobs.format_job_table(all_managed_jobs,
144146
show_all=True,
147+
show_user=False,
145148
return_rows=True)
146149

147150
status_counts = collections.defaultdict(int)

sky/jobs/scheduler.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from sky.jobs import constants as managed_job_constants
5050
from sky.jobs import state
5151
from sky.skylet import constants
52+
from sky.utils import common_utils
5253
from sky.utils import subprocess_utils
5354

5455
logger = sky_logging.init_logger('sky.jobs.controller')
@@ -190,9 +191,12 @@ def submit_job(job_id: int, dag_yaml_path: str, env_file_path: str) -> None:
190191
PENDING. It will tell the scheduler to try and start the job controller, if
191192
there are resources available. It may block to acquire the lock, so it
192193
should not be on the critical path for `sky jobs launch -d`.
194+
195+
The user hash should be set (e.g. via SKYPILOT_USER_ID) before calling this.
193196
"""
194197
with filelock.FileLock(_get_lock_path()):
195-
state.scheduler_set_waiting(job_id, dag_yaml_path, env_file_path)
198+
state.scheduler_set_waiting(job_id, dag_yaml_path, env_file_path,
199+
common_utils.get_user_hash())
196200
maybe_schedule_next_jobs()
197201

198202

sky/jobs/server/core.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,9 @@ def _maybe_restart_controller(
346346

347347

348348
@usage_lib.entrypoint
349-
def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
349+
def queue(refresh: bool,
350+
skip_finished: bool = False,
351+
all_users: bool = False) -> List[Dict[str, Any]]:
350352
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
351353
"""Gets statuses of managed jobs.
352354
@@ -394,6 +396,18 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
394396
f'{returncode}')
395397

396398
jobs = managed_job_utils.load_managed_job_queue(job_table_payload)
399+
400+
if not all_users:
401+
# For backwards compatibility, we show jobs that have user_hash None.
402+
def user_hash_matches_or_missing(job: Dict[str, Any]) -> bool:
403+
if 'user_hash' not in job:
404+
return True
405+
if job['user_hash'] is None:
406+
return True
407+
return job['user_hash'] == common_utils.get_user_hash()
408+
409+
jobs = list(filter(user_hash_matches_or_missing, jobs))
410+
397411
if skip_finished:
398412
# Filter out the finished jobs. If a multi-task job is partially
399413
# finished, we will include all its tasks.
@@ -402,14 +416,16 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]:
402416
non_finished_job_ids = {job['job_id'] for job in non_finished_tasks}
403417
jobs = list(
404418
filter(lambda job: job['job_id'] in non_finished_job_ids, jobs))
419+
405420
return jobs
406421

407422

408423
@usage_lib.entrypoint
409424
# pylint: disable=redefined-builtin
410425
def cancel(name: Optional[str] = None,
411426
job_ids: Optional[List[int]] = None,
412-
all: bool = False) -> None:
427+
all: bool = False,
428+
all_users: bool = False) -> None:
413429
# NOTE(dev): Keep the docstring consistent between the Python API and CLI.
414430
"""Cancels managed jobs.
415431
@@ -425,18 +441,23 @@ def cancel(name: Optional[str] = None,
425441
stopped_message='All managed jobs should have finished.')
426442

427443
job_id_str = ','.join(map(str, job_ids))
428-
if sum([bool(job_ids), name is not None, all]) != 1:
429-
argument_str = f'job_ids={job_id_str}' if job_ids else ''
430-
argument_str += f' name={name}' if name is not None else ''
431-
argument_str += ' all' if all else ''
444+
if sum([bool(job_ids), name is not None, all, all_users]) != 1:
445+
arguments = []
446+
arguments += [f'job_ids={job_id_str}'] if job_ids else []
447+
arguments += [f'name={name}'] if name is not None else []
448+
arguments += ['all'] if all else []
449+
arguments += ['all_users'] if all_users else []
432450
with ux_utils.print_exception_no_traceback():
433451
raise ValueError('Can only specify one of JOB_IDS or name or all. '
434-
f'Provided {argument_str!r}.')
452+
f'Provided {" ".join(arguments)!r}.')
435453

436454
backend = backend_utils.get_backend_from_handle(handle)
437455
assert isinstance(backend, backends.CloudVmRayBackend)
438456
if all:
439457
code = managed_job_utils.ManagedJobCodeGen.cancel_jobs_by_id(None)
458+
elif all_users:
459+
code = managed_job_utils.ManagedJobCodeGen.cancel_jobs_by_id(
460+
None, all_users=True)
440461
elif job_ids:
441462
code = managed_job_utils.ManagedJobCodeGen.cancel_jobs_by_id(job_ids)
442463
else:

sky/jobs/server/server.py

+9
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,18 @@ async def download_logs(
109109
@router.get('/dashboard')
110110
async def dashboard(request: fastapi.Request,
111111
user_hash: str) -> fastapi.Response:
112+
# TODO(cooperc): Support showing only jobs for a specific user.
113+
114+
# FIX(zhwu/cooperc/eric): Fix log downloading (assumes global
115+
# /download_log/xx route)
116+
112117
# Note: before #4717, each user had their own controller, and thus their own
113118
# dashboard. Now, all users share the same controller, so this isn't really
114119
# necessary. TODO(cooperc): clean up.
120+
121+
# TODO: Put this in an executor to avoid blocking the main server thread.
122+
# It can take a long time if it needs to check the controller status.
123+
115124
# Find the port for the dashboard of the user
116125
os.environ[constants.USER_ID_ENV_VAR] = user_hash
117126
server_common.reload_for_new_request(client_entrypoint=None,

0 commit comments

Comments
 (0)