Skip to content

Firebase Ml Fix upload file naming #392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 27, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,25 @@
}


_RPC_CODE_TO_ERROR_CODE = {
1: exceptions.CANCELLED,
2: exceptions.UNKNOWN,
3: exceptions.INVALID_ARGUMENT,
4: exceptions.DEADLINE_EXCEEDED,
5: exceptions.NOT_FOUND,
6: exceptions.ALREADY_EXISTS,
7: exceptions.PERMISSION_DENIED,
8: exceptions.RESOURCE_EXHAUSTED,
9: exceptions.FAILED_PRECONDITION,
10: exceptions.ABORTED,
11: exceptions.OUT_OF_RANGE,
13: exceptions.INTERNAL,
14: exceptions.UNAVAILABLE,
15: exceptions.DATA_LOSS,
16: exceptions.UNAUTHENTICATED,
}


def _get_initialized_app(app):
if app is None:
return firebase_admin.get_app()
Expand Down Expand Up @@ -120,9 +139,9 @@ def handle_operation_error(error):
message='Unknown error while making a remote service call: {0}'.format(error),
cause=error)

status_code = error.get('code')
rpc_code = error.get('code')
message = error.get('message')
error_code = _http_status_to_error_code(status_code)
error_code = _rpc_code_to_error_code(rpc_code)
err_type = _error_code_to_exception_type(error_code)
return err_type(message=message)

Expand Down Expand Up @@ -283,6 +302,9 @@ def _http_status_to_error_code(status):
"""Maps an HTTP status to a platform error code."""
return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN)

def _rpc_code_to_error_code(rpc_code):
"""Maps an RPC code to a platform error code."""
return _RPC_CODE_TO_ERROR_CODE.get(rpc_code, exceptions.UNKNOWN)

def _error_code_to_exception_type(code):
"""Maps a platform error code to an exception type."""
Expand Down
25 changes: 13 additions & 12 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


import datetime
import numbers
import re
import time
import requests
Expand Down Expand Up @@ -246,20 +245,12 @@ def display_name(self, display_name):
@property
def create_time(self):
"""The time the model was created."""
seconds = self._data.get('createTime', {}).get('seconds')
if not isinstance(seconds, numbers.Number):
return None

return datetime.datetime.fromtimestamp(float(seconds))
return self._data.get('createTime', None)

@property
def update_time(self):
"""The time the model was last updated."""
seconds = self._data.get('updateTime', {}).get('seconds')
if not isinstance(seconds, numbers.Number):
return None

return datetime.datetime.fromtimestamp(float(seconds))
return self._data.get('updateTime', None)

@property
def validation_error(self):
Expand Down Expand Up @@ -439,8 +430,18 @@ def _parse_gcs_tflite_uri(uri):
def upload(bucket_name, model_file_name, app):
"""Upload a model file to the specified Storage bucket."""
_CloudStorageClient._assert_gcs_enabled()

# Calculate the destination file_name (remove path if present)
file_name = model_file_name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole thing can be condensed to:

file_name = os.path.basename(model_file_name)

That should work in Windows environments too (that's probably not a requirement, but just in case)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! done.

file_name_pattern = re.compile(r'^(?P<path>.*)/(?P<file_name>[^/]+)$')
matcher = file_name_pattern.match(model_file_name)
if matcher:
# The model_file_name contains at least one '/'
# ignore the path and just keep the file_name
file_name = matcher.group('file_name')

bucket = storage.bucket(bucket_name, app=app)
blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name)
blob_name = _CloudStorageClient.BLOB_NAME.format(file_name)
blob = bucket.blob(blob_name)
blob.upload_from_filename(model_file_name)
return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pytest-cov >= 2.4.0
pytest-localserver >= 0.4.1
tox >= 3.6.0

cachecontrol >= 0.12.4
cachecontrol >= 0.12.6
google-api-core[grpc] >= 1.7.0, < 2.0.0dev; platform.python_implementation != 'PyPy'
google-api-python-client >= 1.7.8
google-cloud-firestore >= 0.31.0; platform.python_implementation != 'PyPy'
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
long_description = ('The Firebase Admin Python SDK enables server-side (backend) Python developers '
'to integrate Firebase into their services and applications.')
install_requires = [
'cachecontrol>=0.12.4',
'cachecontrol>=0.12.6',
'google-api-core[grpc] >= 1.7.0, < 2.0.0dev; platform.python_implementation != "PyPy"',
'google-api-python-client >= 1.7.8',
'google-cloud-firestore>=0.31.0; platform.python_implementation != "PyPy"',
Expand Down
2 changes: 1 addition & 1 deletion tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ def test_http_timeout(self):
assert ref._client.timeout == 60
assert ref.get() == {}
assert len(recorder) == 1
assert recorder[0]._extra_kwargs['timeout'] == 60
assert recorder[0]._extra_kwargs['timeout'] == pytest.approx(60, 0.001)

def test_app_delete(self):
app = firebase_admin.initialize_app(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,7 @@ def test_send(self):
msg = messaging.Message(topic='foo')
messaging.send(msg)
assert len(self.recorder) == 1
assert self.recorder[0]._extra_kwargs['timeout'] == 4
assert self.recorder[0]._extra_kwargs['timeout'] == pytest.approx(4, 0.001)

def test_topic_management_timeout(self):
self.fcm_service._client.session.mount(
Expand All @@ -1266,7 +1266,7 @@ def test_topic_management_timeout(self):
)
messaging.subscribe_to_topic(['1'], 'a')
assert len(self.recorder) == 1
assert self.recorder[0]._extra_kwargs['timeout'] == 4
assert self.recorder[0]._extra_kwargs['timeout'] == pytest.approx(4, 0.001)


class TestSend(object):
Expand Down
54 changes: 20 additions & 34 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

"""Test cases for the firebase_admin.ml module."""

import datetime
import json
import pytest

Expand All @@ -27,25 +26,12 @@
PROJECT_ID = 'myProject1'
PAGE_TOKEN = 'pageToken'
NEXT_PAGE_TOKEN = 'nextPageToken'
CREATE_TIME_SECONDS = 1566426374
CREATE_TIME_SECONDS_2 = 1566426385
CREATE_TIME_JSON = {
'seconds': CREATE_TIME_SECONDS
}
CREATE_TIME_DATETIME = datetime.datetime.fromtimestamp(CREATE_TIME_SECONDS)
CREATE_TIME_JSON_2 = {
'seconds': CREATE_TIME_SECONDS_2
}
CREATE_TIME = '2020-01-21T20:44:27.392932Z'
CREATE_TIME_2 = '2020-01-21T21:44:27.392932Z'

UPDATE_TIME = '2020-01-21T22:45:29.392932Z'
UPDATE_TIME_2 = '2020-01-21T23:45:29.392932Z'

UPDATE_TIME_SECONDS = 1566426678
UPDATE_TIME_SECONDS_2 = 1566426691
UPDATE_TIME_JSON = {
'seconds': UPDATE_TIME_SECONDS
}
UPDATE_TIME_DATETIME = datetime.datetime.fromtimestamp(UPDATE_TIME_SECONDS)
UPDATE_TIME_JSON_2 = {
'seconds': UPDATE_TIME_SECONDS_2
}
ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4'
MODEL_HASH = '987987a98b98798d098098e09809fc0893897'
TAG_1 = 'Tag1'
Expand Down Expand Up @@ -130,8 +116,8 @@
CREATED_UPDATED_MODEL_JSON_1 = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_1,
'createTime': CREATE_TIME_JSON,
'updateTime': UPDATE_TIME_JSON,
'createTime': CREATE_TIME,
'updateTime': UPDATE_TIME,
'state': MODEL_STATE_ERROR_JSON,
'etag': ETAG,
'modelHash': MODEL_HASH,
Expand All @@ -142,17 +128,17 @@
LOCKED_MODEL_JSON_1 = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_1,
'createTime': CREATE_TIME_JSON,
'updateTime': UPDATE_TIME_JSON,
'createTime': CREATE_TIME,
'updateTime': UPDATE_TIME,
'tags': TAGS,
'activeOperations': [OPERATION_NOT_DONE_JSON_1]
}

LOCKED_MODEL_JSON_2 = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_2,
'createTime': CREATE_TIME_JSON_2,
'updateTime': UPDATE_TIME_JSON_2,
'createTime': CREATE_TIME_2,
'updateTime': UPDATE_TIME_2,
'tags': TAGS_2,
'activeOperations': [OPERATION_NOT_DONE_JSON_1]
}
Expand All @@ -169,7 +155,7 @@
# Name is required if the operation is not done.
'done': False
}
OPERATION_ERROR_CODE = 400
OPERATION_ERROR_CODE = 3
OPERATION_ERROR_MSG = "Invalid argument"
OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT'
OPERATION_ERROR_JSON_1 = {
Expand All @@ -183,8 +169,8 @@
FULL_MODEL_ERR_STATE_LRO_JSON = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_1,
'createTime': CREATE_TIME_JSON,
'updateTime': UPDATE_TIME_JSON,
'createTime': CREATE_TIME,
'updateTime': UPDATE_TIME,
'state': MODEL_STATE_ERROR_JSON,
'etag': ETAG,
'modelHash': MODEL_HASH,
Expand All @@ -194,8 +180,8 @@
FULL_MODEL_PUBLISHED_JSON = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_1,
'createTime': CREATE_TIME_JSON,
'updateTime': UPDATE_TIME_JSON,
'createTime': CREATE_TIME,
'updateTime': UPDATE_TIME,
'state': MODEL_STATE_PUBLISHED_JSON,
'etag': ETAG,
'modelHash': MODEL_HASH,
Expand Down Expand Up @@ -364,8 +350,8 @@ def test_model_success_err_state_lro(self):
model = ml.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON)
assert model.model_id == MODEL_ID_1
assert model.display_name == DISPLAY_NAME_1
assert model.create_time == CREATE_TIME_DATETIME
assert model.update_time == UPDATE_TIME_DATETIME
assert model.create_time == CREATE_TIME
assert model.update_time == UPDATE_TIME
assert model.validation_error == VALIDATION_ERROR_MSG
assert model.published is False
assert model.etag == ETAG
Expand All @@ -379,8 +365,8 @@ def test_model_success_published(self):
model = ml.Model.from_dict(FULL_MODEL_PUBLISHED_JSON)
assert model.model_id == MODEL_ID_1
assert model.display_name == DISPLAY_NAME_1
assert model.create_time == CREATE_TIME_DATETIME
assert model.update_time == UPDATE_TIME_DATETIME
assert model.create_time == CREATE_TIME
assert model.update_time == UPDATE_TIME
assert model.validation_error is None
assert model.published is True
assert model.etag == ETAG
Expand Down