chromium/third_party/protobuf/csharp/src/Google.Protobuf/ParsingPrimitivesMessages.cs

#region Copyright notice and license
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc.  All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endregion

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Security;
using Google.Protobuf.Collections;

namespace Google.Protobuf
{
    /// <summary>
    /// Reading and skipping messages / groups
    /// </summary>
    [SecuritySafeCritical]
    internal static class ParsingPrimitivesMessages
    {
        private static readonly byte[] ZeroLengthMessageStreamData = new byte[] { 0 };

        /// <summary>
        /// Skips the payload of the field whose tag was read most recently
        /// (<c>state.lastTag</c>), dispatching on the wire type encoded in that tag.
        /// </summary>
        /// <param name="buffer">The span currently being parsed.</param>
        /// <param name="state">The parser state holding the last tag read.</param>
        /// <exception cref="InvalidOperationException">
        /// No tag is pending, i.e. <c>state.lastTag</c> is 0.
        /// </exception>
        /// <exception cref="InvalidProtocolBufferException">
        /// The last tag is an end-group tag, or it carries a wire type that is not handled here.
        /// </exception>
        public static void SkipLastField(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state)
        {
            if (state.lastTag == 0)
            {
                throw new InvalidOperationException("SkipLastField cannot be called at the end of a stream");
            }

            WireFormat.WireType wireType = WireFormat.GetTagWireType(state.lastTag);
            switch (wireType)
            {
                case WireFormat.WireType.StartGroup:
                    // Nested groups are handled through the SkipGroup <-> SkipLastField recursion.
                    SkipGroup(ref buffer, ref state, state.lastTag);
                    break;
                case WireFormat.WireType.EndGroup:
                    throw new InvalidProtocolBufferException(
                        "SkipLastField called on an end-group tag, indicating that the corresponding start-group was missing");
                case WireFormat.WireType.Fixed32:
                    ParsingPrimitives.ParseRawLittleEndian32(ref buffer, ref state);
                    break;
                case WireFormat.WireType.Fixed64:
                    ParsingPrimitives.ParseRawLittleEndian64(ref buffer, ref state);
                    break;
                case WireFormat.WireType.LengthDelimited:
                    var length = ParsingPrimitives.ParseLength(ref buffer, ref state);
                    ParsingPrimitives.SkipRawBytes(ref buffer, ref state, length);
                    break;
                case WireFormat.WireType.Varint:
                    // The decoded value is discarded; only the read position matters.
                    ParsingPrimitives.ParseRawVarint32(ref buffer, ref state);
                    break;
                default:
                    // The low three bits of a tag can also encode 6 or 7, which are not valid wire
                    // types. Without this branch such a tag would be "skipped" without consuming
                    // any payload, and parsing would silently continue on misaligned data.
                    throw new InvalidProtocolBufferException(
                        $"SkipLastField called on a tag with an invalid wire type ({(int) wireType})");
            }
        }

        /// <summary>
        /// Skips every field of a group, up to and including its end-group tag.
        /// </summary>
        /// <param name="buffer">The span currently being parsed.</param>
        /// <param name="state">The parser state; its recursion depth is raised for the duration of the skip.</param>
        /// <param name="startGroupTag">The start-group tag that opened the group being skipped.</param>
        /// <exception cref="InvalidProtocolBufferException">
        /// The recursion limit is reached, the input ends (tag 0) before an end-group tag is seen,
        /// or the end-group tag's field number differs from that of <paramref name="startGroupTag"/>.
        /// </exception>
        public static void SkipGroup(ref ReadOnlySpan<byte> buffer, ref ParserInternalState state, uint startGroupTag)
        {
            // Groups are currently expected to be read through this path, so the recursion depth is
            // tracked here; moving that bookkeeping into tag reading would be a possible alternative.
            state.recursionDepth++;
            if (state.recursionDepth >= state.recursionLimit)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }

            uint closingTag;
            bool insideGroup;
            do
            {
                closingTag = ParsingPrimitives.ParseTag(ref buffer, ref state);
                if (closingTag == 0)
                {
                    throw InvalidProtocolBufferException.TruncatedMessage();
                }

                // An end-group tag terminates the loop; it must not be passed to SkipLastField,
                // which throws for that wire type.
                insideGroup = WireFormat.GetTagWireType(closingTag) != WireFormat.WireType.EndGroup;
                if (insideGroup)
                {
                    // Recursing through SkipLastField is what lets nested groups be skipped.
                    SkipLastField(ref buffer, ref state);
                }
            }
            while (insideGroup);

            int openedField = WireFormat.GetTagFieldNumber(startGroupTag);
            int closedField = WireFormat.GetTagFieldNumber(closingTag);
            if (openedField != closedField)
            {
                throw new InvalidProtocolBufferException(
                    $"Mismatched end-group tag. Started with field {openedField}; ended with field {closedField}");
            }

            state.recursionDepth--;
        }

        /// <summary>
        /// Reads a length-prefixed nested message from <paramref name="ctx"/> and merges it into
        /// <paramref name="message"/>.
        /// </summary>
        /// <param name="ctx">The parse context positioned at the message's length prefix.</param>
        /// <param name="message">The message to merge the parsed data into.</param>
        /// <exception cref="InvalidProtocolBufferException">
        /// The recursion limit has been reached, the nested parse stopped on a non-zero tag,
        /// or the pushed limit was not reached once the nested parse finished.
        /// </exception>
        public static void ReadMessage(ref ParseContext ctx, IMessage message)
        {
            int messageSize = ParsingPrimitives.ParseLength(ref ctx.buffer, ref ctx.state);
            if (ctx.state.recursionDepth >= ctx.state.recursionLimit)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }

            // Confine the nested parse to the declared size and account for one more nesting level.
            int enclosingLimit = SegmentedBufferHelper.PushLimit(ref ctx.state, messageSize);
            ctx.state.recursionDepth++;

            ReadRawMessage(ref ctx, message);

            CheckReadEndOfStreamTag(ref ctx.state);

            // The nested parse must have consumed exactly the number of bytes that was announced.
            bool consumedEverything = SegmentedBufferHelper.IsReachedLimit(ref ctx.state);
            if (!consumedEverything)
            {
                throw InvalidProtocolBufferException.TruncatedMessage();
            }

            ctx.state.recursionDepth--;
            SegmentedBufferHelper.PopLimit(ref ctx.state, enclosingLimit);
        }

        /// <summary>
        /// Reads a length-prefixed map entry from <paramref name="ctx"/> and returns it as a key/value pair.
        /// </summary>
        /// <param name="ctx">The parse context positioned at the entry's length prefix.</param>
        /// <param name="codec">The map codec supplying the key and value codecs (tags, defaults and readers).</param>
        /// <returns>
        /// The decoded pair. A key or value that does not appear in the entry keeps its codec's default
        /// value, except that a <c>null</c> value is replaced by the result of reading a zero-length payload.
        /// </returns>
        /// <exception cref="InvalidProtocolBufferException">
        /// The recursion limit has been reached, the last tag read is non-zero after the entry was
        /// processed, or the pushed limit was not reached.
        /// </exception>
        public static KeyValuePair<TKey, TValue> ReadMapEntry<TKey, TValue>(ref ParseContext ctx, MapField<TKey, TValue>.Codec codec)
        {
            int entrySize = ParsingPrimitives.ParseLength(ref ctx.buffer, ref ctx.state);
            if (ctx.state.recursionDepth >= ctx.state.recursionLimit)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }

            int enclosingLimit = SegmentedBufferHelper.PushLimit(ref ctx.state, entrySize);
            ctx.state.recursionDepth++;

            var keyCodec = codec.KeyCodec;
            var valueCodec = codec.ValueCodec;

            TKey entryKey = keyCodec.DefaultValue;
            TValue entryValue = valueCodec.DefaultValue;

            // Read fields until the tag reader reports 0; anything that is neither key nor value is skipped.
            for (uint fieldTag = ctx.ReadTag(); fieldTag != 0; fieldTag = ctx.ReadTag())
            {
                if (fieldTag == keyCodec.Tag)
                {
                    entryKey = keyCodec.Read(ref ctx);
                }
                else if (fieldTag == valueCodec.Tag)
                {
                    entryValue = valueCodec.Read(ref ctx);
                }
                else
                {
                    SkipLastField(ref ctx.buffer, ref ctx.state);
                }
            }

            // Corner case: the entry carried a key but no value and the value type is a message.
            // Behave as if an empty value had been on the wire, i.e. produce a "default" message.
            if (entryValue == null)
            {
                if (ctx.state.CodedInputStream == null)
                {
                    ParseContext.Initialize(new ReadOnlySequence<byte>(ZeroLengthMessageStreamData), out ParseContext emptyValueCtx);
                    entryValue = valueCodec.Read(ref emptyValueCtx);
                }
                else
                {
                    // The value's message type might not support ParseContext-based parsing, so the
                    // legacy MergeFrom(CodedInputStream) route has to stay available as a fallback.
                    entryValue = valueCodec.Read(new CodedInputStream(ZeroLengthMessageStreamData));
                }
            }

            CheckReadEndOfStreamTag(ref ctx.state);

            // The entry must have consumed exactly the number of bytes that was announced.
            bool consumedEverything = SegmentedBufferHelper.IsReachedLimit(ref ctx.state);
            if (!consumedEverything)
            {
                throw InvalidProtocolBufferException.TruncatedMessage();
            }

            ctx.state.recursionDepth--;
            SegmentedBufferHelper.PopLimit(ref ctx.state, enclosingLimit);

            return new KeyValuePair<TKey, TValue>(entryKey, entryValue);
        }

        /// <summary>
        /// Reads the body of a group into <paramref name="message"/>. The group's field number is
        /// taken from the last tag read before this call.
        /// </summary>
        /// <param name="ctx">The parse context; its last tag supplies the group's field number.</param>
        /// <param name="message">The message to merge the group's contents into.</param>
        /// <exception cref="InvalidProtocolBufferException">
        /// The recursion limit has been reached, or the last tag read after parsing is not the
        /// end-group tag for the same field number.
        /// </exception>
        public static void ReadGroup(ref ParseContext ctx, IMessage message)
        {
            if (ctx.state.recursionDepth >= ctx.state.recursionLimit)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }
            ctx.state.recursionDepth++;

            // Capture the field number before the nested parse, since lastTag is compared again afterwards.
            int groupField = WireFormat.GetTagFieldNumber(ctx.state.lastTag);
            ReadRawMessage(ref ctx, message);

            uint expectedEndTag = WireFormat.MakeTag(groupField, WireFormat.WireType.EndGroup);
            CheckLastTagWas(ref ctx.state, expectedEndTag);

            ctx.state.recursionDepth--;
        }

        /// <summary>
        /// Reads the body of a group with the given field number into an unknown field set.
        /// </summary>
        /// <param name="ctx">The parse context to read from.</param>
        /// <param name="fieldNumber">The field number the closing end-group tag must carry.</param>
        /// <param name="set">The unknown field set that merges the group's contents.</param>
        /// <exception cref="InvalidProtocolBufferException">
        /// The recursion limit has been reached, or the last tag read after merging is not the
        /// end-group tag for <paramref name="fieldNumber"/>.
        /// </exception>
        public static void ReadGroup(ref ParseContext ctx, int fieldNumber, UnknownFieldSet set)
        {
            bool depthExhausted = ctx.state.recursionDepth >= ctx.state.recursionLimit;
            if (depthExhausted)
            {
                throw InvalidProtocolBufferException.RecursionLimitExceeded();
            }
            ctx.state.recursionDepth++;

            set.MergeGroupFrom(ref ctx);

            uint expectedEndTag = WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup);
            CheckLastTagWas(ref ctx.state, expectedEndTag);

            ctx.state.recursionDepth--;
        }

        /// <summary>
        /// Merges data from <paramref name="ctx"/> into <paramref name="message"/>, using the
        /// ParseContext-based path when the message implements <see cref="IBufferMessage"/> and
        /// otherwise falling back to <c>MergeFrom(CodedInputStream)</c>.
        /// </summary>
        /// <param name="ctx">The parse context to read from.</param>
        /// <param name="message">The message to merge into.</param>
        /// <exception cref="InvalidProtocolBufferException">
        /// The message does not implement <see cref="IBufferMessage"/> and the context has no
        /// associated <see cref="CodedInputStream"/> to fall back to.
        /// </exception>
        public static void ReadRawMessage(ref ParseContext ctx, IMessage message)
        {
            if (message is IBufferMessage bufferMessage)
            {
                bufferMessage.InternalMergeFrom(ref ctx);
                return;
            }

            // Getting here means we have run into a nested message built from older generated code
            // that lacks the InternalMergeFrom method taking a ParseContext. Such a message can still
            // be parsed, at a small cost: we locate the CodedInputStream that started this parsing
            // process and bring its internal state up to date first. The cost is low (essentially
            // copying the contents of a struct) and is only paid when an application mixes older and
            // newer generated code; regenerating the code from the .proto files removes it.
            CodedInputStream originatingStream = ctx.state.CodedInputStream;
            if (originatingStream == null)
            {
                // Parsing was started without a CodedInputStream (e.g. the ParseContext was created
                // directly from a ReadOnlySequence), so one of the new parsing APIs was used at the
                // top level. In that situation it is reasonable to require every nested message to
                // ship generated code with ParseContext support, and to fail otherwise.
                throw new InvalidProtocolBufferException($"Message {message.GetType().Name} doesn't provide the generated method that enables ParseContext-based parsing. You might need to regenerate the generated protobuf code.");
            }

            ctx.CopyStateTo(originatingStream);
            try
            {
                // Fallback: parse through the CodedInputStream that started the current parsing tree.
                message.MergeFrom(originatingStream);
            }
            finally
            {
                // Pull the stream's state back into the context even if the legacy parse threw.
                ctx.LoadStateFrom(originatingStream);
            }
        }

        /// <summary>
        /// Verifies that the last tag recorded in <paramref name="state"/> is 0, meaning the
        /// end of the stream was reached where it was expected.
        /// </summary>
        /// <param name="state">The parser state whose last tag is inspected.</param>
        /// <exception cref="InvalidProtocolBufferException">The last tag read was not 0.</exception>
        public static void CheckReadEndOfStreamTag(ref ParserInternalState state)
        {
            if (state.lastTag == 0)
            {
                return;
            }

            throw InvalidProtocolBufferException.MoreDataAvailable();
        }

        /// <summary>
        /// Verifies that the last tag recorded in <paramref name="state"/> equals
        /// <paramref name="expectedTag"/>.
        /// </summary>
        /// <param name="state">The parser state whose last tag is inspected.</param>
        /// <param name="expectedTag">The tag the last read is required to match.</param>
        /// <exception cref="InvalidProtocolBufferException">The last tag differs from the expected one.</exception>
        private static void CheckLastTagWas(ref ParserInternalState state, uint expectedTag)
        {
            bool tagMatches = state.lastTag == expectedTag;
            if (!tagMatches)
            {
                throw InvalidProtocolBufferException.InvalidEndTag();
            }
        }
    }
}