diff --git a/Dockerfile b/Dockerfile index 569cd5b..0d5be3b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,8 +16,7 @@ RUN mkdir /training && pip3 install precise-lite WORKDIR /training -COPY scripts/precise-lite-convert-h5 /usr/local/bin/precise-lite-convert-h5 -COPY scripts/convert_h5.py /usr/local/lib/python3.6/dist-packages/precise_lite/scripts/convert_h5.py +COPY scripts/train_incremental.py /usr/local/lib/python3.6/dist-packages/precise_lite/scripts/train_incremental.py COPY entrypoint.sh /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] diff --git a/scripts/convert_h5.py b/scripts/convert_h5.py deleted file mode 100644 index cada14f..0000000 --- a/scripts/convert_h5.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2019 Mycroft AI Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from prettyparse import Usage - -from precise_lite.model import create_model, ModelParams -from precise_lite.scripts.train import BaseScript - - -class ConvertH5(BaseScript): - usage = Usage(''' - Convert a model to h5 format - - :model str - Keras model file (.net) to load from - - :-s --sensitivity float 0.2 - Weighted loss bias. Higher values decrease increase positives - - :-nv --no-validation - Disable accuracy and validation calculation - to improve speed during training - - :-em --extra-metrics - Add extra metrics during training - ... - ''') | BaseScript.usage - - def __init__(self, args): - super().__init__(args) - - params = ModelParams( - skip_acc=self.args.no_validation, extra_metrics=self.args.extra_metrics, - loss_bias=1.0 - self.args.sensitivity - ) - self.model = create_model(self.args.model, params) - - def run(self): - print('Converting model to h5') - - self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format - -main = ConvertH5.run_main - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/scripts/precise-lite-convert-h5 b/scripts/precise-lite-convert-h5 deleted file mode 100755 index 439304b..0000000 --- a/scripts/precise-lite-convert-h5 +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import re -import sys -from precise_lite.scripts.convert_h5 import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/scripts/train_incremental.py b/scripts/train_incremental.py new file mode 100644 index 0000000..2997ecd --- /dev/null +++ b/scripts/train_incremental.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +# Copyright 2019 Mycroft AI Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from os import makedirs, rename +from os.path import basename, splitext, isfile, join +from prettyparse import Usage +from random import random +from typing import * + +from precise_lite.model import create_model, ModelParams +from precise_lite.network_runner import Listener, KerasRunner +from precise_lite.params import pr +from precise_lite.scripts.train import TrainScript +from precise_lite.train_data import TrainData +from precise_lite.util import load_audio, save_audio, glob_all, chunk_audio + + +def load_trained_fns(model_name: str) -> list: + progress_file = model_name.replace('.net', '') + '.trained.txt' + if isfile(progress_file): + print('Starting from saved position in', progress_file) + with open(progress_file, 'rb') as f: + return f.read().decode('utf8', 'surrogatepass').split('\n') + return [] + + +def save_trained_fns(trained_fns: list, model_name: str): + with open(model_name.replace('.net', '') + '.trained.txt', 'wb') as f: + f.write('\n'.join(trained_fns).encode('utf8', 'surrogatepass')) + + +class TrainIncrementalScript(TrainScript): + usage = Usage(''' + Train a model to inhibit activation by + marking false activations and retraining + + :-e --epochs int 1 + Number of epochs to train before continuing evaluation + + :-ds --delay-samples int 10 + Number of false activations to save before re-training + + :-c --chunk-size int 2048 + Number of samples between testing the neural network + + :-r --random-data-folder str data/random + Folder with properly encoded wav files of + random audio that should not cause an activation + + :-th --threshold float 0.5 + Network output to be considered activated + + ... + ''') | TrainScript.usage + + def __init__(self, args): + super().__init__(args) + + for i in ( + join(self.args.folder, 'not-wake-word', 'generated'), + join(self.args.folder, 'test', 'not-wake-word', 'generated') + ): + makedirs(i, exist_ok=True) + + self.trained_fns = load_trained_fns(self.args.model) + self.audio_buffer = np.zeros(pr.buffer_samples, dtype=float) + + params = ModelParams( + skip_acc=self.args.no_validation, extra_metrics=self.args.extra_metrics, + loss_bias=1.0 - self.args.sensitivity + ) + model = create_model(self.args.model, params) + self.listener = Listener(self.args.model, self.args.chunk_size, runner_cls=KerasRunner) + self.listener.runner = KerasRunner(self.args.model) + self.listener.runner.model = model + self.samples_since_train = 0 + + @staticmethod + def load_data(args: Any): + data = TrainData.from_tags(args.tags_file, args.tags_folder) + return data.load(True, not args.no_validation) + + def retrain(self): + """Train for a session, pulling in any new data from the filesystem""" + folder = TrainData.from_folder(self.args.folder) + train_data, test_data = folder.load(True, not self.args.no_validation) + + train_data = TrainData.merge(train_data, self.sampled_data) + test_data = TrainData.merge(test_data, self.test) + train_inputs, train_outputs = train_data + print() + try: + self.listener.runner.model.fit( + train_inputs, train_outputs, self.args.batch_size, self.epoch + self.args.epochs, + validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch + ) + finally: + self.listener.runner.model.save(self.args.model + '/') + + def train_on_audio(self, fn: str): + """Run through a single audio file""" + save_test = random() > 0.8 + audio = load_audio(fn) + num_chunks = len(audio) // self.args.chunk_size + + self.listener.clear() + + for i, chunk in enumerate(chunk_audio(audio, self.args.chunk_size)): + print('\r' + str(i * 100. / num_chunks) + '%', end='', flush=True) + self.audio_buffer = np.concatenate((self.audio_buffer[len(chunk):], chunk)) + conf = self.listener.update(chunk) + if conf > self.args.threshold: + self.samples_since_train += 1 + name = splitext(basename(fn))[0] + '-' + str(i) + '.wav' + name = join(self.args.folder, 'test' if save_test else '', 'not-wake-word', + 'generated', name) + save_audio(name, self.audio_buffer) + print() + print('Saved to:', name) + + if not save_test and self.samples_since_train >= self.args.delay_samples and \ + self.args.epochs > 0: + self.samples_since_train = 0 + self.retrain() + + def run(self): + """ + Begin reading through audio files, saving false + activations and retraining when necessary + """ + for fn in glob_all(self.args.random_data_folder, '*.wav'): + if fn in self.trained_fns: + print('Skipping ' + fn + '...') + continue + + print('Starting file ' + fn + '...') + self.train_on_audio(fn) + print('\r100% ') + + self.trained_fns.append(fn) + save_trained_fns(self.trained_fns, self.args.model) + + +main = TrainIncrementalScript.run_main + +if __name__ == '__main__': + main() \ No newline at end of file