diff --git a/Dockerfile b/Dockerfile index d5aa8a0..d1ade70 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,3 +13,8 @@ FROM python-ubuntu RUN apt-get update && apt-get install liblapack-dev libpng-dev libfreetype6-dev libqhull-dev git pkg-config portaudio19-dev swig libpulse-ocaml-dev gfortran libopenblas-dev libatlas-base-dev -y RUN pip3 install precise-lite + +COPY scripts/precise-convert-h5 /usr/local/bin/precise-convert-h5 +RUN chmod +x /usr/local/bin/precise-convert-h5 + +COPY scripts/convert_h5.py /usr/local/lib/python3.6/dist-packages/precise_lite/scripts/convert_h5.py \ No newline at end of file diff --git a/scripts/convert_h5.py b/scripts/convert_h5.py new file mode 100644 index 0000000..cada14f --- /dev/null +++ b/scripts/convert_h5.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +# Copyright 2019 Mycroft AI Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from prettyparse import Usage + +from precise_lite.model import create_model, ModelParams +from precise_lite.scripts.train import BaseScript + + +class ConvertH5(BaseScript): + usage = Usage(''' + Convert a model to h5 format + + :model str + Keras model file (.net) to load from + + :-s --sensitivity float 0.2 + Weighted loss bias. Higher values decrease increase positives + + :-nv --no-validation + Disable accuracy and validation calculation + to improve speed during training + + :-em --extra-metrics + Add extra metrics during training + ... + ''') | BaseScript.usage + + def __init__(self, args): + super().__init__(args) + + params = ModelParams( + skip_acc=self.args.no_validation, extra_metrics=self.args.extra_metrics, + loss_bias=1.0 - self.args.sensitivity + ) + self.model = create_model(self.args.model, params) + + def run(self): + print('Converting model to h5') + + self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format + +main = ConvertH5.run_main + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/precise-convert-h5 b/scripts/precise-convert-h5 new file mode 100644 index 0000000..439304b --- /dev/null +++ b/scripts/precise-convert-h5 @@ -0,0 +1,8 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +import re +import sys +from precise_lite.scripts.convert_h5 import main +if __name__ == '__main__': + sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) + sys.exit(main())