Skip to content

Quick pass at filling in missing docstrings #367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 18, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(self, display_name=None, tags=None, model_format=None):

@classmethod
def from_dict(cls, data, app=None):
"""Create an instance of the object from a dict."""
data_copy = dict(data)
tflite_format = None
tflite_format_data = data_copy.pop('tfliteModel', None)
Expand Down Expand Up @@ -223,13 +224,16 @@ def __ne__(self, other):

@property
def model_id(self):
"""The model's ID, unique to the project."""
if not self._data.get('name'):
return None
_, model_id = _validate_and_parse_name(self._data.get('name'))
return model_id

@property
def display_name(self):
"""The model's display name, used to refer to the model in code and in
the Firebase console."""
return self._data.get('displayName')

@display_name.setter
Expand All @@ -239,7 +243,7 @@ def display_name(self, display_name):

@property
def create_time(self):
"""Returns the creation timestamp"""
"""The time the model was created."""
seconds = self._data.get('createTime', {}).get('seconds')
if not isinstance(seconds, numbers.Number):
return None
Expand All @@ -248,7 +252,7 @@ def create_time(self):

@property
def update_time(self):
"""Returns the last update timestamp"""
"""The time the model was last updated."""
seconds = self._data.get('updateTime', {}).get('seconds')
if not isinstance(seconds, numbers.Number):
return None
Expand All @@ -257,22 +261,28 @@ def update_time(self):

@property
def validation_error(self):
"""Validation error message."""
return self._data.get('state', {}).get('validationError', {}).get('message')

@property
def published(self):
"""True if the model is published and available for clients to
download."""
return bool(self._data.get('state', {}).get('published'))

@property
def etag(self):
"""The entity tag (ETag) of the model resource."""
return self._data.get('etag')

@property
def model_hash(self):
"""SHA256 hash of the model binary."""
return self._data.get('modelHash')

@property
def tags(self):
"""Tag strings, used for filtering query results."""
return self._data.get('tags')

@tags.setter
Expand All @@ -282,6 +292,7 @@ def tags(self, tags):

@property
def locked(self):
"""True if the Model object is locked by an active operation."""
return bool(self._data.get('activeOperations') and
len(self._data.get('activeOperations')) > 0)

Expand All @@ -307,6 +318,8 @@ def wait_for_unlocked(self, max_time_seconds=None):

@property
def model_format(self):
"""The model's ``ModelFormat`` object, which represents the model's
format and storage location."""
return self._model_format

@model_format.setter
Expand All @@ -317,6 +330,7 @@ def model_format(self, model_format):
return self

def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
copy = dict(self._data)
if self._model_format:
copy.update(self._model_format.as_dict(for_upload=for_upload))
Expand All @@ -326,6 +340,7 @@ def as_dict(self, for_upload=False):
class ModelFormat(object):
"""Abstract base class representing a Model Format such as TFLite."""
def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
raise NotImplementedError


Expand All @@ -344,6 +359,7 @@ def __init__(self, model_source=None):

@classmethod
def from_dict(cls, data):
"""Create an instance of the object from a dict."""
data_copy = dict(data)
model_source = None
gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None)
Expand All @@ -366,6 +382,7 @@ def __ne__(self, other):

@property
def model_source(self):
"""The TF Lite model's location."""
return self._model_source

@model_source.setter
Expand All @@ -377,9 +394,11 @@ def model_source(self, model_source):

@property
def size_bytes(self):
"""The size in bytes of the TF Lite model."""
return self._data.get('sizeBytes')

def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
copy = dict(self._data)
if self._model_source:
copy.update(self._model_source.as_dict(for_upload=for_upload))
Expand All @@ -389,6 +408,7 @@ def as_dict(self, for_upload=False):
class TFLiteModelSource(object):
"""Abstract base class representing a model source for TFLite format models."""
def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
raise NotImplementedError


Expand All @@ -415,6 +435,7 @@ def _parse_gcs_tflite_uri(uri):

@staticmethod
def upload(bucket_name, model_file_name, app):
"""Upload a model file to the specified Storage bucket."""
_CloudStorageClient._assert_gcs_enabled()
bucket = storage.bucket(bucket_name, app=app)
blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name)
Expand Down Expand Up @@ -531,6 +552,7 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None):

@property
def gcs_tflite_uri(self):
"""URI of the model file in Cloud Storage."""
return self._gcs_tflite_uri

@gcs_tflite_uri.setter
Expand All @@ -542,6 +564,7 @@ def _get_signed_gcs_tflite_uri(self):
return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app)

def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
if for_upload:
return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()}

Expand Down Expand Up @@ -578,11 +601,12 @@ def list_filter(self):

@property
def next_page_token(self):
"""Token identifying the next page of results."""
return self._list_response.get('nextPageToken', '')

@property
def has_next_page(self):
"""A boolean indicating whether more pages are available."""
"""True if more pages are available."""
return bool(self.next_page_token)

def get_next_page(self):
Expand Down