Source code for malaysia_ai_projects.pembalakan

import logging
import numpy as np
from herpetologist import check_type
from malaya_boilerplate.frozen_graph import (
    generate_session,
    nodes_session,
    load_graph,
)
from malaya_boilerplate.execute import execute_graph
from skimage.transform import resize
from typing import List

_available_models = {
    'efficientnet-b4': {
        'Size (MB)': 79.9,
        'Test Loss': 0.08283,
    },
    'efficientnet-b4-quantized': {
        'Size (MB)': 20.7,
        'Test Loss': 0.08283,
    },
    'efficientnet-b2': {
        'Size (MB)': 66.4,
        'Test Loss': 0.09731,
    },
    'efficientnet-b2-quantized': {
        'Size (MB)': 17.1,
        'Test Loss': 0.09731,
    },
}

repo_id = 'malay-huggingface/pembalakan'
huggingface_filenames = {
    'efficientnet-b4': 'efficientnet-b4/frozen_model.pb',
    'efficientnet-b4-quantized': 'efficientnet-b4/frozen_model.pb.quantized',
    'efficientnet-b2': 'efficientnet-b2/frozen_model.pb',
    'efficientnet-b2-quantized': 'efficientnet-b2/frozen_model.pb.quantized'
}


def available_model():
    """
    List available Pembalakan models.
    """
    from malaysia_ai_projects.utils import describe_availability

    return describe_availability(_available_models)
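
# Usage sketch (illustrative only, not part of the module); assumes the
# package is installed as `malaysia_ai_projects`:
#
#   >>> from malaysia_ai_projects import pembalakan
#   >>> pembalakan.available_model()
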
@check_type
def load(model: str = 'efficientnet-b2', **kwargs):
    """
    Load Pembalakan model.

    Parameters
    ----------
    model : str, optional (default='efficientnet-b2')
        Model architecture supported. Allowed values:

        * ``'efficientnet-b4'`` - EfficientNet B4 + Unet.
        * ``'efficientnet-b4-quantized'`` - EfficientNet B4 + Unet with dynamic quantization.
        * ``'efficientnet-b2'`` - EfficientNet B2 + Unet.
        * ``'efficientnet-b2-quantized'`` - EfficientNet B2 + Unet with dynamic quantization.

    Returns
    -------
    result : malaysia_ai_projects.pembalakan.Model class
    """
    model = model.lower()
    if model not in _available_models:
        raise ValueError(
            'model not supported, please check supported models from `malaysia_ai_projects.pembalakan.available_model()`.'
        )

    from huggingface_hub import hf_hub_download

    path = hf_hub_download(repo_id=repo_id, filename=huggingface_filenames[model])
    g = load_graph(package=None, frozen_graph_filename=path, **kwargs)
    input_nodes, output_nodes = nodes_session(g, ['input'], ['logits'])
    return Model(
        input_nodes=input_nodes,
        output_nodes=output_nodes,
        sess=generate_session(graph=g, **kwargs),
    )
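
# Usage sketch (illustrative only); downloading the frozen graph requires
# network access to the `malay-huggingface/pembalakan` repo on the Hugging Face Hub:
#
#   >>> from malaysia_ai_projects import pembalakan
#   >>> model = pembalakan.load(model='efficientnet-b2')
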
class Model:
    def __init__(self, input_nodes, output_nodes, sess):
        self._input_nodes = input_nodes
        self._output_nodes = output_nodes
        self._sess = sess
        # The frozen graphs expect 256 x 256 inputs; `predict` resizes accordingly.
        self._size = 256

    def _execute(self, inputs, input_labels, output_labels):
        # Run the frozen graph: feed `inputs` into `input_labels` and fetch `output_labels`.
        return execute_graph(
            inputs=inputs,
            input_labels=input_labels,
            output_labels=output_labels,
            sess=self._sess,
            input_nodes=self._input_nodes,
            output_nodes=self._output_nodes,
        )
    def predict(self, inputs: List[np.ndarray]):
        """
        Predict segmentation masks for a batch of images.

        Parameters
        ----------
        inputs : List[np.ndarray]
            List of np.ndarray, each of shape [H, W, 3]; `H` and `W` can be dynamic.

        Returns
        -------
        result : List[np.ndarray]
            Per-pixel masks, each resized back to its input's original height and width.
        """
        sizes, batch = [], []
        for img in inputs:
            sizes.append(img.shape[:-1])
            batch.append(resize(img, (self._size, self._size), anti_aliasing=False))

        r = self._execute(
            inputs=[batch],
            input_labels=['input'],
            output_labels=['logits'],
        )
        v = r['logits']
        outputs = []
        for no, output in enumerate(v):
            outputs.append(np.around(resize(output, sizes[no], anti_aliasing=False)))
        return outputs
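
# End-to-end sketch (illustrative only; `satellite.png` is a hypothetical
# RGB image of shape [H, W, 3], read here with scikit-image):
#
#   >>> from skimage.io import imread
#   >>> from malaysia_ai_projects import pembalakan
#   >>> model = pembalakan.load(model='efficientnet-b2')
#   >>> img = imread('satellite.png')
#   >>> masks = model.predict([img])  # masks resized back to each input's original size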