[jobs] multi-user managed jobs #4787

Merged: 10 commits (Feb 26, 2025)
64 changes: 49 additions & 15 deletions sky/cli.py
@@ -1379,12 +1379,14 @@ def exec(cluster: Optional[str], cluster_option: Optional[str],
def _handle_jobs_queue_request(
request_id: str,
show_all: bool,
show_user: bool,
limit_num_jobs_to_show: bool = False,
is_called_by_user: bool = False) -> Tuple[Optional[int], str]:
"""Get the in-progress managed jobs.

Args:
show_all: Show all information of each job (e.g., region, price).
show_user: Show the user who submitted the job.
limit_num_jobs_to_show: If True, limit the number of jobs to show to
_NUM_MANAGED_JOBS_TO_SHOW_IN_STATUS, which is mainly used by
`sky status`.
@@ -1452,6 +1454,7 @@ def _handle_jobs_queue_request(
if limit_num_jobs_to_show else None)
msg = managed_jobs.format_job_table(managed_jobs_,
show_all=show_all,
show_user=show_user,
max_jobs=max_jobs_to_show)
return num_in_progress_jobs, msg

@@ -1561,7 +1564,9 @@ def _status_kubernetes(show_all: bool):
click.echo(f'\n{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'Managed jobs'
f'{colorama.Style.RESET_ALL}')
msg = managed_jobs.format_job_table(all_jobs, show_all=show_all)
msg = managed_jobs.format_job_table(all_jobs,
show_all=show_all,
show_user=False)
click.echo(msg)
if any(['sky-serve-controller' in c.cluster_name for c in all_clusters]):
# TODO: Parse serve controllers and show services separately.
@@ -1779,7 +1784,8 @@ def status(verbose: bool, refresh: bool, ip: bool, endpoints: bool,
show_managed_jobs = show_managed_jobs and not any([clusters, ip, endpoints])
if show_managed_jobs:
managed_jobs_queue_request_id = managed_jobs.queue(refresh=False,
skip_finished=True)
skip_finished=True,
all_users=all_users)
show_endpoints = endpoints or endpoint is not None
show_single_endpoint = endpoint is not None
show_services = show_services and not any([clusters, ip, endpoints])
@@ -1859,6 +1865,7 @@ def status(verbose: bool, refresh: bool, ip: bool, endpoints: bool,
num_in_progress_jobs, msg = _handle_jobs_queue_request(
managed_jobs_queue_request_id,
show_all=False,
show_user=False,
limit_num_jobs_to_show=not all,
is_called_by_user=False)
except KeyboardInterrupt:
@@ -2751,7 +2758,7 @@ def start(
def down(
clusters: List[str],
all: bool, # pylint: disable=redefined-builtin
all_users: bool, # pylint: disable=redefined-builtin
all_users: bool,
yes: bool,
purge: bool,
async_call: bool,
@@ -2812,7 +2819,9 @@ def _hint_or_raise_for_down_jobs_controller(controller_name: str,
with rich_utils.client_status(
'[bold cyan]Checking for in-progress managed jobs[/]'):
try:
request_id = managed_jobs.queue(refresh=False, skip_finished=True)
request_id = managed_jobs.queue(refresh=False,
skip_finished=True,
all_users=True)
managed_jobs_ = sdk.stream_and_get(request_id)
except exceptions.ClusterNotUpError as e:
if controller.value.connection_error_hint in str(e):
@@ -2836,7 +2845,9 @@ def _hint_or_raise_for_down_jobs_controller(controller_name: str,
'jobs (output of `sky jobs queue`) will be lost.')
click.echo(msg)
if managed_jobs_:
job_table = managed_jobs.format_job_table(managed_jobs_, show_all=False)
job_table = managed_jobs.format_job_table(managed_jobs_,
show_all=False,
show_user=True)
msg = controller.value.decline_down_for_dirty_controller_hint
# Add prefix to each line to align with the bullet point.
msg += '\n'.join(
@@ -3905,9 +3916,16 @@ def jobs_launch(
is_flag=True,
required=False,
help='Show only pending/running jobs\' information.')
@click.option('--all-users',
'-u',
default=False,
is_flag=True,
required=False,
help='Show jobs from all users.')
@usage_lib.entrypoint
# pylint: disable=redefined-builtin
def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool):
def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool,
all_users: bool):
"""Show statuses of managed jobs.

Each managed job can have one of the following statuses:
@@ -3964,9 +3982,10 @@ def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool):
click.secho('Fetching managed job statuses...', fg='cyan')
with rich_utils.client_status('[cyan]Checking managed jobs[/]'):
managed_jobs_request_id = managed_jobs.queue(
refresh=refresh, skip_finished=skip_finished)
refresh=refresh, skip_finished=skip_finished, all_users=all_users)
_, msg = _handle_jobs_queue_request(managed_jobs_request_id,
show_all=verbose,
show_user=all_users,
is_called_by_user=True)
if not skip_finished:
in_progress_only_hint = ''
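Note: the new --all-users flag maps directly onto the SDK path used elsewhere in this diff. A minimal sketch of the equivalent programmatic call (a running API server and the imports shown in dashboard.py are assumed):

from sky import jobs as managed_jobs
from sky.client import sdk

# Fetch every user's in-progress managed jobs, then render the table with the
# new user column enabled.
request_id = managed_jobs.queue(refresh=False, skip_finished=True, all_users=True)
jobs_ = sdk.stream_and_get(request_id)
print(managed_jobs.format_job_table(jobs_, show_all=False, show_user=True))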
@@ -3989,16 +4008,23 @@ def jobs_queue(verbose: bool, refresh: bool, skip_finished: bool):
is_flag=True,
default=False,
required=False,
help='Cancel all managed jobs.')
help='Cancel all managed jobs for the current user.')
@click.option('--yes',
'-y',
is_flag=True,
default=False,
required=False,
help='Skip confirmation prompt.')
@click.option('--all-users',
'-u',
is_flag=True,
default=False,
required=False,
help='Cancel all managed jobs from all users.')
@usage_lib.entrypoint
# pylint: disable=redefined-builtin
def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool,
all_users: bool):
"""Cancel managed jobs.

You can provide either a job name or a list of job IDs to be cancelled.
@@ -4015,25 +4041,33 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
$ sky jobs cancel 1 2 3
"""
job_id_str = ','.join(map(str, job_ids))
if sum([bool(job_ids), name is not None, all]) != 1:
argument_str = f'--job-ids {job_id_str}' if job_ids else ''
argument_str += f' --name {name}' if name is not None else ''
argument_str += ' --all' if all else ''
if sum([bool(job_ids), name is not None, all, all_users]) != 1:
arguments = []
arguments += [f'--job-ids {job_id_str}'] if job_ids else []
arguments += [f'--name {name}'] if name is not None else []
arguments += ['--all'] if all else []
arguments += ['--all-users'] if all_users else []
raise click.UsageError(
'Can only specify one of JOB_IDS, --name, --all, or --all-users. '
f'Provided {argument_str!r}.')
f'Provided {" ".join(arguments)!r}.')

if not yes:
job_identity_str = (f'managed jobs with IDs {job_id_str}'
if job_ids else repr(name))
if all:
job_identity_str = 'all managed jobs'
if all_users:
job_identity_str = 'all managed jobs FOR ALL USERS'
click.confirm(f'Cancelling {job_identity_str}. Proceed?',
default=True,
abort=True,
show_default=True)

sdk.stream_and_get(managed_jobs.cancel(job_ids=job_ids, name=name, all=all))
sdk.stream_and_get(
managed_jobs.cancel(job_ids=job_ids,
name=name,
all=all,
all_users=all_users))
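To spell out the selector rule enforced near the top of jobs_cancel above: exactly one of the job IDs, --name, --all, or the new --all-users may be given per invocation. A small self-contained sketch of the check (the helper name is made up for illustration):

def _exactly_one_selector(job_ids: tuple, name, all_jobs: bool, all_users: bool) -> bool:
    # Mirrors the sum(...) check in jobs_cancel above.
    return sum([bool(job_ids), name is not None, all_jobs, all_users]) == 1

assert _exactly_one_selector((1, 2, 3), None, False, False)  # --job-ids only
assert _exactly_one_selector((), None, False, True)          # --all-users only
assert not _exactly_one_selector((), None, True, True)       # --all plus --all-users is rejected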


@jobs.command('logs', cls=_DocumentedCodeCommand)
14 changes: 10 additions & 4 deletions sky/jobs/client/sdk.py
@@ -85,14 +85,16 @@ def launch(
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def queue(refresh: bool,
skip_finished: bool = False) -> server_common.RequestId:
skip_finished: bool = False,
all_users: bool = False) -> server_common.RequestId:
"""Gets statuses of managed jobs.

Please refer to sky.cli.job_queue for documentation.

Args:
refresh: Whether to restart the jobs controller if it is stopped.
skip_finished: Whether to skip finished jobs.
all_users: Whether to show all users' jobs.

Returns:
The request ID of the queue request.
@@ -126,6 +128,7 @@ def queue(refresh: bool,
body = payloads.JobsQueueBody(
refresh=refresh,
skip_finished=skip_finished,
all_users=all_users,
)
response = requests.post(
f'{server_common.get_server_url()}/jobs/queue',
@@ -138,9 +141,10 @@
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
def cancel(
name: Optional[str] = None,
job_ids: Optional[List[int]] = None,
all: bool = False, # pylint: disable=redefined-builtin
name: Optional[str] = None,
job_ids: Optional[List[int]] = None,
all: bool = False, # pylint: disable=redefined-builtin
all_users: bool = False,
) -> server_common.RequestId:
"""Cancels managed jobs.

@@ -150,6 +154,7 @@ def cancel(
name: Name of the managed job to cancel.
job_ids: IDs of the managed jobs to cancel.
all: Whether to cancel all managed jobs.
all_users: Whether to cancel all managed jobs from all users.

Returns:
The request ID of the cancel request.
@@ -162,6 +167,7 @@ def cancel(
name=name,
job_ids=job_ids,
all=all,
all_users=all_users,
)
response = requests.post(
f'{server_common.get_server_url()}/jobs/cancel',
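A minimal client-side usage sketch for the extended cancel signature (an API server reachable from the client is assumed; the import alias is illustrative):

from sky.client import sdk
from sky.jobs.client import sdk as jobs_sdk

# Cancel every user's managed jobs and block until the request completes.
request_id = jobs_sdk.cancel(all_users=True)
sdk.stream_and_get(request_id)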
2 changes: 1 addition & 1 deletion sky/jobs/constants.py
@@ -40,7 +40,7 @@
# The version of the lib files that jobs/utils use. Whenever there is an API
# change for the jobs/utils, we need to bump this version and update
# job.utils.ManagedJobCodeGen to handle the version update.
MANAGED_JOBS_VERSION = 1
MANAGED_JOBS_VERSION = 2
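To make the comment above concrete: code generated for the controller can branch on the version it finds there. A purely hypothetical guard of the kind ManagedJobCodeGen could emit (not part of this diff; `controller_version` stands for the MANAGED_JOBS_VERSION baked into the controller-side jobs/utils, however the codegen exposes it):

# Hypothetical, illustrative guard inside code generated for the controller.
if controller_version < 2:
    raise RuntimeError(
        'The jobs controller is running an older jobs/utils version; '
        'update the controller to use the multi-user features added in '
        'version 2.')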

# The command for setting up the jobs dashboard on the controller. It firstly
# checks if the systemd services are available, and if not (e.g., Kubernetes
5 changes: 4 additions & 1 deletion sky/jobs/dashboard/dashboard.py
@@ -16,6 +16,7 @@
import yaml

from sky import jobs as managed_jobs
from sky.client import sdk
from sky.jobs import constants as managed_job_constants
from sky.utils import common_utils
from sky.utils import controller_utils
@@ -134,14 +135,16 @@ def _extract_launch_history(log_content: str) -> str:
def home():
if not _is_running_on_jobs_controller():
# Experimental: run on laptop (refresh is very slow).
all_managed_jobs = managed_jobs.queue(refresh=True, skip_finished=False)
request_id = managed_jobs.queue(refresh=True, skip_finished=False)
all_managed_jobs = sdk.get(request_id)
else:
job_table = managed_jobs.dump_managed_job_queue()
all_managed_jobs = managed_jobs.load_managed_job_queue(job_table)

timestamp = datetime.datetime.now(datetime.timezone.utc)
rows = managed_jobs.format_job_table(all_managed_jobs,
show_all=True,
show_user=False,
return_rows=True)

status_counts = collections.defaultdict(int)
31 changes: 23 additions & 8 deletions sky/jobs/scheduler.py
@@ -49,6 +49,7 @@
from sky.jobs import constants as managed_job_constants
from sky.jobs import state
from sky.skylet import constants
from sky.utils import common_utils
from sky.utils import subprocess_utils

logger = sky_logging.init_logger('sky.jobs.controller')
@@ -151,12 +152,20 @@ def maybe_schedule_next_jobs() -> None:
job_id = maybe_next_job['job_id']
dag_yaml_path = maybe_next_job['dag_yaml_path']

activate_python_env_cmd = (
f'{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV};')
env_file = maybe_next_job['env_file_path']
source_environment_cmd = (f'source {env_file};'
if env_file else '')
run_controller_cmd = ('python -u -m sky.jobs.controller '
f'{dag_yaml_path} --job-id {job_id};')

# If the command line here is changed, please also update
# utils._controller_process_alive. `--job-id X` should be at
# the end.
run_cmd = (f'{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV};'
'python -u -m sky.jobs.controller '
f'{dag_yaml_path} --job-id {job_id}')
run_cmd = (f'{activate_python_env_cmd}'
f'{source_environment_cmd}'
f'{run_controller_cmd}')

logs_dir = os.path.expanduser(
managed_job_constants.JOBS_CONTROLLER_LOGS_DIR)
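For illustration, the three pieces composed above expand into a single shell command; the sketch below uses made-up paths and a stand-in for ACTIVATE_SKY_REMOTE_PYTHON_ENV. `--job-id` stays at the end so utils._controller_process_alive can keep matching on it:

# Standalone sketch of the command composition, with assumed values.
activate_python_env_cmd = 'source ~/skypilot-runtime/bin/activate;'  # stand-in
env_file = '~/.sky/managed_jobs/job-42.env'        # hypothetical env file path
dag_yaml_path = '~/.sky/managed_jobs/job-42.yaml'  # hypothetical DAG yaml path
job_id = 42

source_environment_cmd = f'source {env_file};' if env_file else ''
run_controller_cmd = f'python -u -m sky.jobs.controller {dag_yaml_path} --job-id {job_id};'
print(f'{activate_python_env_cmd}{source_environment_cmd}{run_controller_cmd}')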
@@ -175,16 +184,19 @@ def maybe_schedule_next_jobs() -> None:
pass


def submit_job(job_id: int, dag_yaml_path: str) -> None:
def submit_job(job_id: int, dag_yaml_path: str, env_file_path: str) -> None:
"""Submit an existing job to the scheduler.

This should be called after a job is created in the `spot` table as
PENDING. It will tell the scheduler to try and start the job controller, if
there are resources available. It may block to acquire the lock, so it
should not be on the critical path for `sky jobs launch -d`.

The user hash should be set (e.g. via SKYPILOT_USER_ID) before calling this.
"""
with filelock.FileLock(_get_lock_path()):
state.scheduler_set_waiting(job_id, dag_yaml_path)
state.scheduler_set_waiting(job_id, dag_yaml_path, env_file_path,
common_utils.get_user_hash())
maybe_schedule_next_jobs()
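A minimal usage sketch for the new signature (the job ID, paths, and user hash below are hypothetical; per the docstring, the caller is expected to have the user hash in place, e.g. via SKYPILOT_USER_ID, before handing the job to the scheduler):

import os

from sky.jobs import scheduler

os.environ['SKYPILOT_USER_ID'] = 'abcd1234'  # hypothetical; normally already set in the caller's environment
scheduler.submit_job(job_id=42,
                     dag_yaml_path='~/.sky/managed_jobs/job-42.yaml',
                     env_file_path='~/.sky/managed_jobs/job-42.env')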


@@ -281,12 +293,15 @@ def _can_lauch_in_alive_job() -> bool:

if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('dag_yaml',
type=str,
help='The path to the user job yaml file.')
parser.add_argument('--job-id',
required=True,
type=int,
help='Job id for the controller job.')
parser.add_argument('dag_yaml',
parser.add_argument('--env-file',
type=str,
help='The path to the user job yaml file.')
help='The path to the controller env file.')
args = parser.parse_args()
submit_job(args.job_id, args.dag_yaml)
submit_job(args.job_id, args.dag_yaml, args.env_file)