Change train_incremental to not use h5
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
21e13c00e3
commit
cc948a1ad6
|
@ -16,8 +16,7 @@ RUN mkdir /training && pip3 install precise-lite
|
||||||
|
|
||||||
WORKDIR /training
|
WORKDIR /training
|
||||||
|
|
||||||
COPY scripts/precise-lite-convert-h5 /usr/local/bin/precise-lite-convert-h5
|
COPY scripts/train_incremental.py /usr/local/lib/python3.6/dist-packages/precise_lite/scripts/train_incremental.py
|
||||||
COPY scripts/convert_h5.py /usr/local/lib/python3.6/dist-packages/precise_lite/scripts/convert_h5.py
|
|
||||||
COPY entrypoint.sh /entrypoint.sh
|
COPY entrypoint.sh /entrypoint.sh
|
||||||
|
|
||||||
ENTRYPOINT ["/entrypoint.sh"]
|
ENTRYPOINT ["/entrypoint.sh"]
|
||||||
|
|
|
@ -1,57 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2019 Mycroft AI Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from prettyparse import Usage
|
|
||||||
|
|
||||||
from precise_lite.model import create_model, ModelParams
|
|
||||||
from precise_lite.scripts.train import BaseScript
|
|
||||||
|
|
||||||
|
|
||||||
class ConvertH5(BaseScript):
|
|
||||||
usage = Usage('''
|
|
||||||
Convert a model to h5 format
|
|
||||||
|
|
||||||
:model str
|
|
||||||
Keras model file (.net) to load from
|
|
||||||
|
|
||||||
:-s --sensitivity float 0.2
|
|
||||||
Weighted loss bias. Higher values decrease increase positives
|
|
||||||
|
|
||||||
:-nv --no-validation
|
|
||||||
Disable accuracy and validation calculation
|
|
||||||
to improve speed during training
|
|
||||||
|
|
||||||
:-em --extra-metrics
|
|
||||||
Add extra metrics during training
|
|
||||||
...
|
|
||||||
''') | BaseScript.usage
|
|
||||||
|
|
||||||
def __init__(self, args):
|
|
||||||
super().__init__(args)
|
|
||||||
|
|
||||||
params = ModelParams(
|
|
||||||
skip_acc=self.args.no_validation, extra_metrics=self.args.extra_metrics,
|
|
||||||
loss_bias=1.0 - self.args.sensitivity
|
|
||||||
)
|
|
||||||
self.model = create_model(self.args.model, params)
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
print('Converting model to h5')
|
|
||||||
|
|
||||||
self.model.save(self.args.model + '.h5') # Save with '.h5' file extension to force format
|
|
||||||
|
|
||||||
main = ConvertH5.run_main
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
|
@ -1,8 +0,0 @@
|
||||||
#!/usr/bin/python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from precise_lite.scripts.convert_h5 import main
|
|
||||||
if __name__ == '__main__':
|
|
||||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
|
||||||
sys.exit(main())
|
|
|
@ -0,0 +1,159 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2019 Mycroft AI Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
from os import makedirs, rename
|
||||||
|
from os.path import basename, splitext, isfile, join
|
||||||
|
from prettyparse import Usage
|
||||||
|
from random import random
|
||||||
|
from typing import *
|
||||||
|
|
||||||
|
from precise_lite.model import create_model, ModelParams
|
||||||
|
from precise_lite.network_runner import Listener, KerasRunner
|
||||||
|
from precise_lite.params import pr
|
||||||
|
from precise_lite.scripts.train import TrainScript
|
||||||
|
from precise_lite.train_data import TrainData
|
||||||
|
from precise_lite.util import load_audio, save_audio, glob_all, chunk_audio
|
||||||
|
|
||||||
|
|
||||||
|
def load_trained_fns(model_name: str) -> list:
|
||||||
|
progress_file = model_name.replace('.net', '') + '.trained.txt'
|
||||||
|
if isfile(progress_file):
|
||||||
|
print('Starting from saved position in', progress_file)
|
||||||
|
with open(progress_file, 'rb') as f:
|
||||||
|
return f.read().decode('utf8', 'surrogatepass').split('\n')
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def save_trained_fns(trained_fns: list, model_name: str):
|
||||||
|
with open(model_name.replace('.net', '') + '.trained.txt', 'wb') as f:
|
||||||
|
f.write('\n'.join(trained_fns).encode('utf8', 'surrogatepass'))
|
||||||
|
|
||||||
|
|
||||||
|
class TrainIncrementalScript(TrainScript):
|
||||||
|
usage = Usage('''
|
||||||
|
Train a model to inhibit activation by
|
||||||
|
marking false activations and retraining
|
||||||
|
|
||||||
|
:-e --epochs int 1
|
||||||
|
Number of epochs to train before continuing evaluation
|
||||||
|
|
||||||
|
:-ds --delay-samples int 10
|
||||||
|
Number of false activations to save before re-training
|
||||||
|
|
||||||
|
:-c --chunk-size int 2048
|
||||||
|
Number of samples between testing the neural network
|
||||||
|
|
||||||
|
:-r --random-data-folder str data/random
|
||||||
|
Folder with properly encoded wav files of
|
||||||
|
random audio that should not cause an activation
|
||||||
|
|
||||||
|
:-th --threshold float 0.5
|
||||||
|
Network output to be considered activated
|
||||||
|
|
||||||
|
...
|
||||||
|
''') | TrainScript.usage
|
||||||
|
|
||||||
|
def __init__(self, args):
|
||||||
|
super().__init__(args)
|
||||||
|
|
||||||
|
for i in (
|
||||||
|
join(self.args.folder, 'not-wake-word', 'generated'),
|
||||||
|
join(self.args.folder, 'test', 'not-wake-word', 'generated')
|
||||||
|
):
|
||||||
|
makedirs(i, exist_ok=True)
|
||||||
|
|
||||||
|
self.trained_fns = load_trained_fns(self.args.model)
|
||||||
|
self.audio_buffer = np.zeros(pr.buffer_samples, dtype=float)
|
||||||
|
|
||||||
|
params = ModelParams(
|
||||||
|
skip_acc=self.args.no_validation, extra_metrics=self.args.extra_metrics,
|
||||||
|
loss_bias=1.0 - self.args.sensitivity
|
||||||
|
)
|
||||||
|
model = create_model(self.args.model, params)
|
||||||
|
self.listener = Listener(self.args.model, self.args.chunk_size, runner_cls=KerasRunner)
|
||||||
|
self.listener.runner = KerasRunner(self.args.model)
|
||||||
|
self.listener.runner.model = model
|
||||||
|
self.samples_since_train = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_data(args: Any):
|
||||||
|
data = TrainData.from_tags(args.tags_file, args.tags_folder)
|
||||||
|
return data.load(True, not args.no_validation)
|
||||||
|
|
||||||
|
def retrain(self):
|
||||||
|
"""Train for a session, pulling in any new data from the filesystem"""
|
||||||
|
folder = TrainData.from_folder(self.args.folder)
|
||||||
|
train_data, test_data = folder.load(True, not self.args.no_validation)
|
||||||
|
|
||||||
|
train_data = TrainData.merge(train_data, self.sampled_data)
|
||||||
|
test_data = TrainData.merge(test_data, self.test)
|
||||||
|
train_inputs, train_outputs = train_data
|
||||||
|
print()
|
||||||
|
try:
|
||||||
|
self.listener.runner.model.fit(
|
||||||
|
train_inputs, train_outputs, self.args.batch_size, self.epoch + self.args.epochs,
|
||||||
|
validation_data=test_data, callbacks=self.callbacks, initial_epoch=self.epoch
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self.listener.runner.model.save(self.args.model + '/')
|
||||||
|
|
||||||
|
def train_on_audio(self, fn: str):
|
||||||
|
"""Run through a single audio file"""
|
||||||
|
save_test = random() > 0.8
|
||||||
|
audio = load_audio(fn)
|
||||||
|
num_chunks = len(audio) // self.args.chunk_size
|
||||||
|
|
||||||
|
self.listener.clear()
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunk_audio(audio, self.args.chunk_size)):
|
||||||
|
print('\r' + str(i * 100. / num_chunks) + '%', end='', flush=True)
|
||||||
|
self.audio_buffer = np.concatenate((self.audio_buffer[len(chunk):], chunk))
|
||||||
|
conf = self.listener.update(chunk)
|
||||||
|
if conf > self.args.threshold:
|
||||||
|
self.samples_since_train += 1
|
||||||
|
name = splitext(basename(fn))[0] + '-' + str(i) + '.wav'
|
||||||
|
name = join(self.args.folder, 'test' if save_test else '', 'not-wake-word',
|
||||||
|
'generated', name)
|
||||||
|
save_audio(name, self.audio_buffer)
|
||||||
|
print()
|
||||||
|
print('Saved to:', name)
|
||||||
|
|
||||||
|
if not save_test and self.samples_since_train >= self.args.delay_samples and \
|
||||||
|
self.args.epochs > 0:
|
||||||
|
self.samples_since_train = 0
|
||||||
|
self.retrain()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
Begin reading through audio files, saving false
|
||||||
|
activations and retraining when necessary
|
||||||
|
"""
|
||||||
|
for fn in glob_all(self.args.random_data_folder, '*.wav'):
|
||||||
|
if fn in self.trained_fns:
|
||||||
|
print('Skipping ' + fn + '...')
|
||||||
|
continue
|
||||||
|
|
||||||
|
print('Starting file ' + fn + '...')
|
||||||
|
self.train_on_audio(fn)
|
||||||
|
print('\r100% ')
|
||||||
|
|
||||||
|
self.trained_fns.append(fn)
|
||||||
|
save_trained_fns(self.trained_fns, self.args.model)
|
||||||
|
|
||||||
|
|
||||||
|
main = TrainIncrementalScript.run_main
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue