Напиши простую нейросеть для генерации текста на

Напиши простую нейросеть для генерации текста на основе промта и обучающих данных на LTSM
Дополнительные параметры: codeFlowType: codeWrite codeFlowLang: C#
Создано: 22.08.2023 23:43 Обновлено: 1 год назад

Напиши простую нейросеть для генерации текста на основе промта и обучающих данных на LTSM

using System;
using System.Collections.Generic;
using System.Linq;
using Keras;
using Keras.Models;
using Keras.PreProcessing.Sequence;
using Numpy;
using Python.Runtime;

namespace TextGeneration
{
    class Program
    {
        static void Main(string[] args)
        {
            // Пример обучающих данных
            string[] trainingData = new string[]
            {
                "The cat sat on the mat",
                "The dog barked",
                "The mouse ran away",
                "The bird sang"
            };

            // Создание словаря символов
            var vocab = GetVocabulary(trainingData);

            // Создание последовательностей входных и выходных данных для обучения
            List<string> inputSequences = new List<string>();
            List<string> targetSequences = new List<string>();
            foreach (var text in trainingData)
            {
                for (int i = 0; i < text.Length - 1; i++)
                {
                    inputSequences.Add(text.Substring(0, i + 1));
                    targetSequences.Add(text[i + 1].ToString());
                }
            }

            // Преобразование последовательностей символов в числовые данные
            var inputSequenceData = ConvertSequencesToData(inputSequences, vocab);
            var targetSequenceData = ConvertSequencesToData(targetSequences, vocab);

            // Создание и обучение модели LSTM
            var model = CreateModel(vocab.Count);
            TrainModel(model, inputSequenceData, targetSequenceData);

            // Промпт для генерации текста
            string prompt = "The";

            // Генерация текста на основе промпта
            string generatedText = GenerateText(model, prompt, vocab);

            Console.WriteLine(generatedText);
        }

        static Dictionary<char, int> GetVocabulary(string[] data)
        {
            var vocab = new Dictionary<char, int>();
            int index = 0;

            foreach (var text in data)
            {
                foreach (var character in text)
                {
                    if (!vocab.ContainsKey(character))
                    {
                        vocab.Add(character, index);
                        index++;
                    }
                }
            }

            return vocab;
        }

        static NDarray ConvertSequencesToData(List<string> sequences, Dictionary<char, int> vocab)
        {
            int sequenceLength = sequences[0].Length;
            int numSequences = sequences.Count;
            int vocabSize = vocab.Count;

            var data = np.zeros(new Shape(numSequences, sequenceLength, vocabSize));
            for (int i = 0; i < numSequences; i++)
            {
                for (int j = 0; j < sequenceLength; j++)
                {
                    char character = sequences[i][j];
                    int vocabIndex = vocab[character];
                    data[i, j, vocabIndex] = 1;
                }
            }

            return data;
        }

        static Sequential CreateModel(int vocabSize)
        {
            var model = new Sequential();
            model.Add(new LSTM(128, input_shape: new Shape(1, vocabSize)));
            model.Add(new Dense(vocabSize, activation: "softmax"));
            model.Compile(loss: "categorical_crossentropy", optimizer: "adam");

            return model;
        }

        static void TrainModel(Sequential model, NDarray input, NDarray target)
        {
            model.Fit(input, target, batch_size: 16, epochs: 100, verbose: 0);
        }

        static string GenerateText(Sequential model, string prompt, Dictionary<char, int> vocab)
        {
            var text = prompt;
            int sequenceLength = prompt.Length;
            int vocabSize = vocab.Count;

            using (Py.GIL())
            {
                dynamic py = Py.Import("numpy");
                var generatedText = np.zeros(new Shape(1, sequenceLength, vocabSize));
                for (int i = 0; i < sequenceLength; i++)
                {
                    char character = prompt[i];
                    int vocabIndex = vocab[character];
                    generatedText[0, i, vocabIndex] = 1;
                }

                for (int i = 0; i < 50; i++) // Генерация 50 символов
                {
                    var predictions = model.Predict(generatedText)[0];
                    int predictedIndex = np.random.choice(predictions.argmax());
                    char predictedCharacter = vocab.FirstOrDefault(x => x.Value == predictedIndex).Key;
                    text += predictedCharacter;
                    generatedText = np.concatenate((generatedText[:, 1:, :], py.np.zeros((1, 1, vocabSize))));
                    generatedText[0, sequenceLength - 1, predictedIndex] = 1;
                }
            }

            return text;
        }
    }
}

Используется библиотека Keras для создания и обучения модели LTSM. Она позволяет создавать слои нейронной сети и компилировать модель с функцией потерь и оптимизатором. В данном случае мы создаем модель LTSM с входным слоем размерностью (1, размер словаря), скрытым слоем LTSM размерностью 128 и выходным слоем с функцией активации softmax. Затем модель обучается на входных и выходных данных с использованием функции Fit.

Для генерации текста модель используется последовательно. Перед генерацией текста промпт преобразуется в числовые данные в формате one-hot encoding. Затем модель делает предсказания для каждого символа и выбирает символ с наибольшей вероятностью. Для получения случайного символа с максимальной вероятностью используется функция np.random.choice.

Пример использования:

string[] trainingData = new string[]
{
    "The cat sat on the mat",
    "The dog barked",
    "The mouse ran away",
    "The bird sang"
};

var vocab = GetVocabulary(trainingData);

List<string> inputSequences = new List<string>();
List<string> targetSequences = new List<string>();
foreach (var text in trainingData)
{
    for (int i = 0; i < text.Length - 1; i++)
    {
        inputSequences.Add(text.Substring(0, i + 1));
        targetSequences.Add(text[i + 1].ToString());
    }
}

var inputSequenceData = ConvertSequencesToData(inputSequences, vocab);
var targetSequenceData = ConvertSequencesToData(targetSequences, vocab);

var model = CreateModel(vocab.Count);
TrainModel(model, inputSequenceData, targetSequenceData);

string prompt = "The";

string generatedText = GenerateText(model, prompt, vocab);

Console.WriteLine(generatedText);

В данном примере обучающие данные состоят из нескольких предложений, словарь символов определяется на основе этих данных. Затем создается и обучается модель LTSM. И наконец, используя промпт "The", генерируется текст.