"""
Client API for the Inference microservice.
"""
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union
from requests import RequestException
from sap.aibus.dar.client.base_client import BaseClientWithSession
from sap.aibus.dar.client.exceptions import DARHTTPException, InvalidWorkerCount
from sap.aibus.dar.client.inference_constants import InferencePaths
from sap.aibus.dar.client.util.lists import split_list
#: How many objects can be processed per inference request
LIMIT_OBJECTS_PER_CALL = 50
#: How many labels to predict for a single object by default
TOP_N = 1
# pylint: disable=too-many-arguments
[docs]class InferenceClient(BaseClientWithSession):
"""
A client for the DAR Inference microservice.
This class implements all basic API calls as well as some convenience methods
which wrap individual API calls.
If the API call fails, all methods will raise an :exc:`DARHTTPException`.
"""
[docs] def create_inference_request(
self,
model_name: str,
objects: List[dict],
top_n: int = TOP_N,
retry: bool = True,
) -> dict:
"""
Performs inference for the given *objects* with *model_name*.
For each object in *objects*, returns the *topN* best predictions.
The *retry* parameter determines whether to retry on HTTP errors indicated by
the remote API endpoint or for other connection problems. See :ref:`retry` for
trade-offs involved here.
.. note::
This endpoint called by this method has a limit of *LIMIT_OBJECTS_PER_CALL*
on the number of *objects*. See :meth:`do_bulk_inference` to circumvent
this limit.
.. versionchanged:: 0.13.0
The *retry* parameter now defaults to true. This increases reliability of the
call. See corresponding note on :meth:`do_bulk_inference`.
:param model_name: name of the model used for inference
:param objects: Objects to be classified
:param top_n: How many predictions to return per object
:param retry: whether to retry on errors. Default: True
:return: API response
"""
self.log.debug(
"Submitting Inference request for model '%s' with '%s'"
" objects and top_n '%s' ",
model_name,
len(objects),
top_n,
)
endpoint = InferencePaths.format_inference_endpoint_by_name(model_name)
response = self.session.post_to_endpoint(
endpoint, payload={"topN": top_n, "objects": objects}, retry=retry
)
as_json = response.json()
self.log.debug("Inference response ID: %s", as_json["id"])
return as_json
[docs] def do_bulk_inference(
self,
model_name: str,
objects: List[dict],
top_n: int = TOP_N,
retry: bool = True,
worker_count: int = 4,
) -> List[Union[dict, None]]:
"""
Performs bulk inference for larger collections.
For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits
the data into several smaller Inference requests.
Requests are executed in parallel.
Returns the aggregated values of the *predictions* of the original API response
as returned by :meth:`create_inference_request`. If one of the inference
requests to the service fails, an artificial prediction object is inserted with
the `labels` key set to `None` for each of the objects in the failing request.
Example of a prediction object which indicates an error:
.. code-block:: python
{
'objectId': 'b5cbcb34-7ab9-4da5-b7ec-654c90757eb9',
'labels': None,
'_sdk_error': 'RequestException: Request Error'
}
In case the `objects` passed to this method do not contain the `objectId` field,
the value is set to `None` in the error prediction object:
.. code-block:: python
{
'objectId': None,
'labels': None,
'_sdk_error': 'RequestException: Request Error'
}
.. note::
This method calls the inference endpoint multiple times to process all data.
For non-trial service instances, each call will incur a cost.
To reduce the impact of a failed request, this method will retry failed
requests.
There is a small chance that even retried requests will be charged, e.g.
if a problem occurs with the request on the client side outside the
control of the service and after the service has processed the request.
To disable `retry` behavior simply pass `retry=False` to the method.
Typically, the default behavior of `retry=True` is safe and improves
reliability of bulk inference greatly.
.. versionchanged:: 0.7.0
The default for the `retry` parameter changed from `retry=False` to
`retry=True` for increased reliability in day-to-day operations.
.. versionchanged:: 0.12.0
Requests are now executed in parallel with up to four threads.
Errors are now handled in this method instead of raising an exception and
discarding inference results from previous requests. For objects where the
inference request did not succeed, a replacement `dict` object is placed in
the returned `list`.
This `dict` follows the format of the `ObjectPrediction` object sent by the
service. To indicate that this is a client-side generated placeholder, the
`labels` key for all ObjectPrediction dicts of the failed inference request
has value `None`.
A `_sdk_error` key is added with the Exception details.
.. versionadded:: 0.12.0
The `worker_count` parameter allows to fine-tune the number of concurrent
request threads. Set `worker_count` to `1` to disable concurrent execution of
requests.
:param model_name: name of the model used for inference
:param objects: Objects to be classified
:param top_n: How many predictions to return per object
:param retry: whether to retry on errors. Default: True
:param worker_count: maximum number of concurrent requests
:raises: InvalidWorkerCount if worker_count param is incorrect
:return: the aggregated ObjectPrediction dictionaries
"""
if worker_count is None:
raise InvalidWorkerCount("worker_count cannot be None!")
if worker_count > 4:
msg = "worker_count too high: %s. Up to 4 allowed." % worker_count
raise InvalidWorkerCount(msg)
if worker_count <= 0:
msg = "worker_count must be greater than 0!"
raise InvalidWorkerCount(msg)
def predict_call(work_package):
try:
response = self.create_inference_request(
model_name, work_package, top_n=top_n, retry=retry
)
return response["predictions"]
except (DARHTTPException, RequestException) as exc:
self.log.warning(
"Caught %s during bulk inference. "
"Setting results to None for this batch!",
exc,
exc_info=True,
)
prediction_error = [
{
"objectId": inference_object.get("objectId", None),
"labels": None,
"_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)),
}
for inference_object in work_package
]
return prediction_error
results = []
with ThreadPoolExecutor(max_workers=worker_count) as pool:
results_iterator = pool.map(
predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
)
for predictions in results_iterator:
results.extend(predictions)
return results
[docs] def create_inference_request_with_url(
self,
url: str,
objects: List[dict],
top_n: int = TOP_N,
retry: bool = True,
) -> dict:
"""
Performs inference for the given *objects* against fully-qualified URL.
A complete inference URL can be the passed to the method inference, instead
of constructing URL from using base url and model name
.. versionchanged:: 0.13.0
The *retry* parameter now defaults to true. This increases reliability of the
call. See corresponding note on :meth:`do_bulk_inference`.
:param url: fully-qualified inference URL
:param objects: Objects to be classified
:param top_n: How many predictions to return per object
:param retry: whether to retry on errors. Default: True
:return: API response
"""
self.log.debug(
"Submitting Inference request with '%s'"
" objects and top_n '%s' to url %s",
len(objects),
top_n,
url,
)
response = self.session.post_to_url(
url, payload={"topN": top_n, "objects": objects}, retry=retry
)
as_json = response.json()
self.log.debug("Inference response ID: %s", as_json["id"])
return as_json